Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions scripts/lang_adapt/madx_run_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,29 @@ def load_data(data_args, model_args):
#
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
# download the dataset.
if data_args.dataset_name is not None:
if data_args.dataset_name is not None and data_args.train_file is not None:
# Create a dataset from the provided file
data_files = {}
dataset_args = {}
if data_args.train_file is not None:
data_files["train"] = data_args.train_file
extension = (
data_args.train_file.split(".")[-1]
if data_args.train_file is not None
else data_args.validation_file.split(".")[-1]
)
if extension == "txt":
extension = "text"
dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir, **dataset_args)

# Downloading and loading a dataset from the hub.
raw_hf_datasets = load_dataset(
data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
)
raw_datasets['train'] = datasets.concatenate_datasets(raw_datasets['train'], raw_hf_datasets['train']).shuffle(seed=42)

elif data_args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
raw_datasets = load_dataset(
data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
Expand Down Expand Up @@ -680,7 +702,7 @@ def load_data(data_args, model_args):
elif data_args.max_eval_samples is not None :
raw_datasets = raw_datasets['train'].train_test_split(test_size = data_args.max_eval_samples)
else:
raw_datasets = raw_datasets['train'].train_test_split(test_size = data.args.validation_split_percentage/100.0)
raw_datasets = raw_datasets['train'].train_test_split(test_size = data_args.validation_split_percentage/100.0)
raw_datasets['validation'] = raw_datasets['test']
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
Expand Down