@@ -122,7 +122,10 @@ def mnli_eval(args):
122
122
)
123
123
print (f"Engine info: { text_classify .engine } " )
124
124
125
- label_map = _get_label2id (text_classify .config_path )
125
+ try :
126
+ label_map = _get_label2id (text_classify .config_path )
127
+ except KeyError :
128
+ label_map = {"entailment" : 0 , "neutral" : 1 , "contradiction" : 2 }
126
129
127
130
for idx , sample in _enumerate_progress (mnli_matched , args .max_samples ):
128
131
pred = text_classify ([[sample ["premise" ], sample ["hypothesis" ]]])
@@ -162,7 +165,10 @@ def qqp_eval(args):
162
165
)
163
166
print (f"Engine info: { text_classify .engine } " )
164
167
165
- label_map = _get_label2id (text_classify .config_path )
168
+ try :
169
+ label_map = _get_label2id (text_classify .config_path )
170
+ except KeyError :
171
+ label_map = {"not_duplicate" : 0 , "duplicate" : 1 , "LABEL_0" : 0 , "LABEL_1" : 1 }
166
172
167
173
for idx , sample in _enumerate_progress (qqp , args .max_samples ):
168
174
pred = text_classify ([[sample ["question1" ], sample ["question2" ]]])
@@ -193,7 +199,10 @@ def sst2_eval(args):
193
199
)
194
200
print (f"Engine info: { text_classify .engine } " )
195
201
196
- label_map = _get_label2id (text_classify .config_path )
202
+ try :
203
+ label_map = _get_label2id (text_classify .config_path )
204
+ except KeyError :
205
+ label_map = {"negative" : 0 , "positive" : 1 , "LABEL_0" : 0 , "LABEL_1" : 1 }
197
206
198
207
for idx , sample in _enumerate_progress (sst2 , args .max_samples ):
199
208
pred = text_classify (
0 commit comments