Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit fb673ec

Browse files
committed
eval_downstream.py
1 parent c581741 commit fb673ec

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

src/deepsparse/transformers/eval_downstream.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,10 @@ def mnli_eval(args):
122122
)
123123
print(f"Engine info: {text_classify.engine}")
124124

125-
label_map = _get_label2id(text_classify.config_path)
125+
try:
126+
label_map = _get_label2id(text_classify.config_path)
127+
except KeyError:
128+
label_map = {"entailment": 0, "neutral": 1, "contradiction": 2}
126129

127130
for idx, sample in _enumerate_progress(mnli_matched, args.max_samples):
128131
pred = text_classify([[sample["premise"], sample["hypothesis"]]])
@@ -162,7 +165,10 @@ def qqp_eval(args):
162165
)
163166
print(f"Engine info: {text_classify.engine}")
164167

165-
label_map = _get_label2id(text_classify.config_path)
168+
try:
169+
label_map = _get_label2id(text_classify.config_path)
170+
except KeyError:
171+
label_map = {"not_duplicate": 0, "duplicate": 1, "LABEL_0": 0, "LABEL_1": 1}
166172

167173
for idx, sample in _enumerate_progress(qqp, args.max_samples):
168174
pred = text_classify([[sample["question1"], sample["question2"]]])
@@ -193,7 +199,10 @@ def sst2_eval(args):
193199
)
194200
print(f"Engine info: {text_classify.engine}")
195201

196-
label_map = _get_label2id(text_classify.config_path)
202+
try:
203+
label_map = _get_label2id(text_classify.config_path)
204+
except KeyError:
205+
label_map = {"negative": 0, "positive": 1, "LABEL_0": 0, "LABEL_1": 1}
197206

198207
for idx, sample in _enumerate_progress(sst2, args.max_samples):
199208
pred = text_classify(

0 commit comments

Comments
 (0)