Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 36 additions & 13 deletions fast_bert/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,12 @@ def get_test_examples(self, filename='test.txt', size=-1):
def get_labels(self, filename='labels.csv'):
    """Return the list of label names, loading them on first use.

    The label file is looked up inside ``self.label_dir`` and is read
    without a header row; the first column is converted to strings.
    The result is cached on ``self.labels`` so the file is read at most
    once per instance.

    Args:
        filename: Name of the label file. Names ending in ``.xlsx``
            (case-insensitive) are read with ``pd.read_excel``; anything
            else is read with ``pd.read_csv``.

    Returns:
        The list of labels stored on ``self.labels``.
    """
    if self.labels is None:
        label_path = os.path.join(self.label_dir, filename)
        # Decide on the extension, not on a substring match, so that a name
        # such as "labels.xlsx.csv" is still parsed as CSV.
        is_excel = filename.lower().endswith(".xlsx")
        reader = pd.read_excel if is_excel else pd.read_csv
        self.labels = list(reader(label_path, header=None)[0].astype('str').values)
    return self.labels

def _create_examples(self, lines, set_type):
Expand All @@ -200,7 +204,7 @@ def read_col_file(self, filename):
'''
read file
return format :
[ ['EU', 'B-ORG'], ['rejects', 'O'], ['German', 'B-MISC'], ['call', 'O'], ['to', 'O'], ['boycott', 'O'],
[ ['EU', 'B-ORG'], ['rejects', 'O'], ['German', 'B-MISC'], ['call', 'O'], ['to', 'O'], ['boycott', 'O'],
['British', 'B-MISC'], ['lamb', 'O'], ['.', 'O'] ]
'''
f = open(filename)
Expand Down Expand Up @@ -235,25 +239,40 @@ def __init__(self, data_dir, label_dir):
def get_train_examples(self, filename='train.csv', text_col='text', label_col='label', size=-1):
    """Load the training file and convert it to examples.

    Args:
        filename: File name inside ``self.data_dir``. Names ending in
            ``.xlsx`` (case-insensitive) are read with ``pd.read_excel``;
            anything else is read with ``pd.read_csv``.
        text_col: Column name forwarded to ``self._create_examples``.
        label_col: Column name forwarded to ``self._create_examples``.
        size: ``-1`` uses every row; any other value draws that many rows
            with ``DataFrame.sample``.

    Returns:
        Whatever ``self._create_examples`` returns for the "train" set.
    """
    data_path = os.path.join(self.data_dir, filename)
    # Decide on the extension, not on a substring match.
    is_excel = filename.lower().endswith(".xlsx")
    reader = pd.read_excel if is_excel else pd.read_csv
    data_df = reader(data_path)

    if size != -1:
        data_df = data_df.sample(size)

    return self._create_examples(data_df, "train", text_col=text_col, label_col=label_col)

def get_dev_examples(self, filename='val.csv', text_col='text', label_col='label', size=-1):
    """Load the validation file and convert it to examples.

    Args:
        filename: File name inside ``self.data_dir``. Names ending in
            ``.xlsx`` (case-insensitive) are read with ``pd.read_excel``;
            anything else is read with ``pd.read_csv``.
        text_col: Column name forwarded to ``self._create_examples``.
        label_col: Column name forwarded to ``self._create_examples``.
        size: ``-1`` uses every row; any other value draws that many rows
            with ``DataFrame.sample``.

    Returns:
        Whatever ``self._create_examples`` returns for the "dev" set.
    """
    data_path = os.path.join(self.data_dir, filename)
    # Decide on the extension, not on a substring match.
    is_excel = filename.lower().endswith(".xlsx")
    reader = pd.read_excel if is_excel else pd.read_csv
    data_df = reader(data_path)

    if size != -1:
        data_df = data_df.sample(size)

    return self._create_examples(data_df, "dev", text_col=text_col, label_col=label_col)

def get_test_examples(self, filename='val.csv', text_col='text', label_col='label', size=-1):
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
if ".xlsx" in filename:
data_df = pd.read_excel(os.path.join(self.data_dir, filename))
else:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
# data_df['comment_text'] = data_df['comment_text'].apply(cleanHtml)
if size == -1:
return self._create_examples(data_df, "test", text_col=text_col, label_col=None)
Expand All @@ -263,8 +282,12 @@ def get_test_examples(self, filename='val.csv', text_col='text', label_col='labe
def get_labels(self, filename='labels.csv'):
    """Return the list of label names, loading them on first use.

    The label file is looked up inside ``self.label_dir`` and is read
    without a header row; the first column is converted to strings.
    The result is cached on ``self.labels`` so the file is read at most
    once per instance.

    Args:
        filename: Name of the label file. Names ending in ``.xlsx``
            (case-insensitive) are read with ``pd.read_excel``; anything
            else is read with ``pd.read_csv``.

    Returns:
        The list of labels stored on ``self.labels``.
    """
    if self.labels is None:
        label_path = os.path.join(self.label_dir, filename)
        # Decide on the extension, not on a substring match.
        is_excel = filename.lower().endswith(".xlsx")
        reader = pd.read_excel if is_excel else pd.read_csv
        self.labels = list(reader(label_path, header=None)[0].astype('str').values)
    return self.labels

def _create_examples(self, df, set_type, text_col, label_col):
Expand Down Expand Up @@ -350,13 +373,13 @@ def load(data_dir, backend='nccl', filename="databunch.pkl"):
def __init__(self, data_dir, label_dir, tokenizer, train_file='train.csv', val_file='val.csv', test_data=None,
label_file='labels.csv', text_col='text', label_col='label', bs=32, maxlen=512,
multi_gpu=True, multi_label=False, backend="nccl", model_type='bert', custom_sampler=None):

if isinstance(tokenizer, str):
_,_,tokenizer_class = MODEL_CLASSES[model_type]
# instantiate the new tokeniser object using the tokeniser name
tokenizer = tokenizer_class.from_pretrained(tokenizer, do_lower_case=('uncased' in tokenizer))

self.tokenizer = tokenizer
self.tokenizer = tokenizer
self.data_dir = data_dir
self.maxlen = maxlen
self.bs = bs
Expand Down Expand Up @@ -451,7 +474,7 @@ def __init__(self, data_dir, label_dir, tokenizer, train_file='train.csv', val_f
torch.distributed.init_process_group(backend=backend,
init_method="tcp://localhost:23459",
rank=0, world_size=1)

except:
pass

Expand Down
47 changes: 36 additions & 11 deletions fast_bert/data_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,19 @@ def get_train_examples(
):

if size == -1:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
if ".xlsx" in filename:
data_df = pd.read_excel(os.path.join(self.data_dir, filename))
else:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))

return self._create_examples(
data_df, "train", text_col=text_col, label_col=label_col
)
else:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
if ".xlsx" in filename:
data_df = pd.read_excel(os.path.join(self.data_dir, filename))
else:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
# data_df['comment_text'] = data_df['comment_text'].apply(cleanHtml)
return self._create_examples(
data_df.sample(size), "train", text_col=text_col, label_col=label_col
Expand All @@ -245,20 +251,29 @@ def get_dev_examples(
):

if size == -1:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
if ".xlsx" in filename:
data_df = pd.read_excel(os.path.join(self.data_dir, filename))
else:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
return self._create_examples(
data_df, "dev", text_col=text_col, label_col=label_col
)
else:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
if ".xlsx" in filename:
data_df = pd.read_excel(os.path.join(self.data_dir, filename))
else:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
return self._create_examples(
data_df.sample(size), "dev", text_col=text_col, label_col=label_col
)

def get_test_examples(
self, filename="val.csv", text_col="text", label_col="label", size=-1
):
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
if ".xlsx" in filename:
data_df = pd.read_excel(os.path.join(self.data_dir, filename))
else:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
# data_df['comment_text'] = data_df['comment_text'].apply(cleanHtml)
if size == -1:
return self._create_examples(
Expand All @@ -272,11 +287,18 @@ def get_test_examples(
def get_labels(self, filename="labels.csv"):
    """Return the list of label names, loading them on first use.

    The label file is looked up inside ``self.label_dir`` and is read
    without a header row; the first column is converted to strings.
    The result is cached on ``self.labels`` so the file is read at most
    once per instance.

    Args:
        filename: Name of the label file. Names ending in ``.xlsx``
            (case-insensitive) are read with ``pd.read_excel``; anything
            else is read with ``pd.read_csv``.

    Returns:
        The list of labels stored on ``self.labels``.
    """
    if self.labels is None:
        label_path = os.path.join(self.label_dir, filename)
        # Decide on the extension, not on a substring match.
        is_excel = filename.lower().endswith(".xlsx")
        reader = pd.read_excel if is_excel else pd.read_csv
        first_column = reader(label_path, header=None)[0]
        self.labels = list(first_column.astype("str").values)
    return self.labels

def _create_examples(self, df, set_type, text_col, label_col):
Expand Down Expand Up @@ -345,7 +367,10 @@ def __init__(self, data_dir, filename, text_col, label_col):
self.text_col = text_col
self.label_col = label_col

self.data = pd.read_csv(os.path.join(data_dir, filename))
if ".xlsx" in filename:
self.data = pd.read_excel(os.path.join(data_dir, filename))
else:
self.data = pd.read_csv(os.path.join(data_dir, filename))

def __getitem__(self, idx):
    """Return the ``(text, label)`` pair stored at index label ``idx``.

    Both values are looked up in ``self.data`` with ``.loc``, using the
    column names held in ``self.text_col`` and ``self.label_col``.
    """
    row_text = self.data.loc[idx, self.text_col]
    row_label = self.data.loc[idx, self.label_col]
    return row_text, row_label
Expand Down