Skip to content

Commit 4046e66

Browse files
authored
examples: only use keep_linebreaks when reading TXT files (huggingface#13320)
* examples: only use keep_linebreaks when reading TXT files for all CLM examples
1 parent b6f332e commit 4046e66

File tree

4 files changed

+22
-14
lines changed

4 files changed

+22
-14
lines changed

examples/flax/language-modeling/run_clm_flax.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -157,7 +157,7 @@ class DataTrainingArguments:
157157
metadata={"help": "The number of processes to use for the preprocessing."},
158158
)
159159
keep_linebreaks: bool = field(
160-
default=True, metadata={"help": "Whether to keep line breaks when using CSV/JSON/TXT files or not."}
160+
default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
161161
)
162162

163163
def __post_init__(self):
@@ -305,29 +305,31 @@ def main():
305305
)
306306
else:
307307
data_files = {}
308+
dataset_args = {}
308309
if data_args.train_file is not None:
309310
data_files["train"] = data_args.train_file
310311
if data_args.validation_file is not None:
311312
data_files["validation"] = data_args.validation_file
312313
extension = data_args.train_file.split(".")[-1]
313314
if extension == "txt":
314315
extension = "text"
315-
dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
316+
dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
317+
dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir, **dataset_args)
316318

317319
if "validation" not in dataset.keys():
318320
dataset["validation"] = load_dataset(
319321
extension,
320-
keep_linebreaks=data_args.keep_linebreaks,
321322
data_files=data_files,
322323
split=f"train[:{data_args.validation_split_percentage}%]",
323324
cache_dir=model_args.cache_dir,
325+
**dataset_args,
324326
)
325327
dataset["train"] = load_dataset(
326328
extension,
327-
keep_linebreaks=data_args.keep_linebreaks,
328329
data_files=data_files,
329330
split=f"train[{data_args.validation_split_percentage}%:]",
330331
cache_dir=model_args.cache_dir,
332+
**dataset_args,
331333
)
332334
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
333335
# https://huggingface.co/docs/datasets/loading_datasets.html.

examples/pytorch/language-modeling/run_clm.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -173,7 +173,7 @@ class DataTrainingArguments:
173173
metadata={"help": "The number of processes to use for the preprocessing."},
174174
)
175175
keep_linebreaks: bool = field(
176-
default=True, metadata={"help": "Whether to keep line breaks when using CSV/JSON/TXT files or not."}
176+
default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
177177
)
178178

179179
def __post_init__(self):
@@ -269,6 +269,7 @@ def main():
269269
)
270270
else:
271271
data_files = {}
272+
dataset_args = {}
272273
if data_args.train_file is not None:
273274
data_files["train"] = data_args.train_file
274275
if data_args.validation_file is not None:
@@ -280,22 +281,23 @@ def main():
280281
)
281282
if extension == "txt":
282283
extension = "text"
283-
raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
284+
dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
285+
raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir, **dataset_args)
284286
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
285287
if "validation" not in raw_datasets.keys():
286288
raw_datasets["validation"] = load_dataset(
287289
extension,
288-
keep_linebreaks=data_args.keep_linebreaks,
289290
data_files=data_files,
290291
split=f"train[:{data_args.validation_split_percentage}%]",
291292
cache_dir=model_args.cache_dir,
293+
**dataset_args,
292294
)
293295
raw_datasets["train"] = load_dataset(
294296
extension,
295-
keep_linebreaks=data_args.keep_linebreaks,
296297
data_files=data_files,
297298
split=f"train[{data_args.validation_split_percentage}%:]",
298299
cache_dir=model_args.cache_dir,
300+
**dataset_args,
299301
)
300302

301303
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at

examples/pytorch/language-modeling/run_clm_no_trainer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def parse_args():
174174
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
175175
)
176176
parser.add_argument(
177-
"--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using CSV/JSON/TXT files."
177+
"--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using TXT files."
178178
)
179179

180180
args = parser.parse_args()
@@ -248,27 +248,29 @@ def main():
248248
)
249249
else:
250250
data_files = {}
251+
dataset_args = {}
251252
if args.train_file is not None:
252253
data_files["train"] = args.train_file
253254
if args.validation_file is not None:
254255
data_files["validation"] = args.validation_file
255256
extension = args.train_file.split(".")[-1]
256257
if extension == "txt":
257258
extension = "text"
258-
raw_datasets = load_dataset(extension, data_files=data_files)
259+
dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks
260+
raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
259261
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
260262
if "validation" not in raw_datasets.keys():
261263
raw_datasets["validation"] = load_dataset(
262264
extension,
263-
keep_linebreaks=not args.no_keep_linebreaks,
264265
data_files=data_files,
265266
split=f"train[:{args.validation_split_percentage}%]",
267+
**dataset_args,
266268
)
267269
raw_datasets["train"] = load_dataset(
268270
extension,
269-
keep_linebreaks=not args.no_keep_linebreaks,
270271
data_files=data_files,
271272
split=f"train[{args.validation_split_percentage}%:]",
273+
**dataset_args,
272274
)
273275

274276
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at

examples/tensorflow/language-modeling/run_clm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ class DataTrainingArguments:
187187
},
188188
)
189189
keep_linebreaks: bool = field(
190-
default=True, metadata={"help": "Whether to keep line breaks when using CSV/JSON/TXT files or not."}
190+
default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
191191
)
192192

193193
def __post_init__(self):
@@ -321,14 +321,16 @@ def main():
321321
)
322322
else:
323323
data_files = {}
324+
dataset_args = {}
324325
if data_args.train_file is not None:
325326
data_files["train"] = data_args.train_file
326327
if data_args.validation_file is not None:
327328
data_files["validation"] = data_args.validation_file
328329
extension = data_args.train_file.split(".")[-1]
329330
if extension == "txt":
330331
extension = "text"
331-
raw_datasets = load_dataset(extension, keep_linebreaks=data_args.keep_linebreaks, data_files=data_files)
332+
dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
333+
raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
332334
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
333335
# https://huggingface.co/docs/datasets/loading_datasets.html.
334336
# endregion

0 commit comments

Comments (0)