1
- from benchmarkstt .schema import Schema
1
+ from benchmarkstt .schema import Schema , Item
2
2
import logging
3
- from benchmarkstt .diff import Differ
3
+ from benchmarkstt .diff import Differ , factory as differ_factory
4
4
from benchmarkstt .diff .core import RatcliffObershelp
5
5
from benchmarkstt .diff .formatter import format_diff
6
6
from benchmarkstt .metrics import Metric
12
12
OpcodeCounts = namedtuple ('OpcodeCounts' ,
13
13
('equal' , 'replace' , 'insert' , 'delete' ))
14
14
15
+ type_schema = Union [Schema , list ]
16
+
15
17
16
18
def traversible (schema , key = None ):
17
19
if key is None :
18
20
key = 'item'
19
- return [word [key ] for word in schema ]
21
+ return [item if type ( item ) is str else item [key ] for item in schema ]
20
22
21
23
22
24
def get_differ (a , b , differ_class : Differ ):
@@ -41,7 +43,7 @@ def __init__(self, differ_class: Differ = None, dialect: str = None):
41
43
self ._differ_class = differ_class
42
44
self ._dialect = dialect
43
45
44
- def compare (self , ref : Schema , hyp : Schema ):
46
+ def compare (self , ref : type_schema , hyp : type_schema ):
45
47
differ = get_differ (ref , hyp , differ_class = self ._differ_class )
46
48
a = traversible (ref )
47
49
b = traversible (hyp )
@@ -82,24 +84,17 @@ class WER(Metric):
82
84
INS_PENALTY = 1
83
85
SUB_PENALTY = 1
84
86
85
- def __init__ (self , mode = None , differ_class : Differ = None ):
87
+ def __init__ (self , mode = None , differ_class : Union [ str , Differ , None ] = None ):
86
88
self ._mode = mode
87
89
88
90
if differ_class is None :
89
91
differ_class = RatcliffObershelp
90
92
self ._differ_class = differ_class
93
+
91
94
if mode == self .MODE_HUNT :
92
95
self .DEL_PENALTY = self .INS_PENALTY = .5
93
96
94
- def compare (self , ref : Schema , hyp : Schema ) -> float :
95
- if self ._mode == self .MODE_LEVENSHTEIN :
96
- ref_list = [i ['item' ] for i in ref ]
97
- hyp_list = [i ['item' ] for i in hyp ]
98
- total_ref = len (ref_list )
99
- if total_ref == 0 :
100
- return 1
101
- return editdistance .eval (ref_list , hyp_list ) / total_ref
102
-
97
+ def compare (self , ref : type_schema , hyp : type_schema ) -> float :
103
98
diffs = get_differ (ref , hyp , differ_class = self ._differ_class )
104
99
105
100
counts = diffs .get_opcode_counts ()
@@ -141,25 +136,21 @@ class CER(Metric):
141
136
will first be split into words, ['aa','bb','cc'], and
142
137
then merged into a final string for evaluation: 'aabbcc'.
143
138
144
- :param mode: 'levenshtein' (default).
145
- :param differ_class: For future use.
139
+ :param differ_class: see :py:mod:`benchmarkstt.Differ.core`
146
140
"""
147
141
148
- # CER modes
149
- MODE_LEVENSHTEIN = 'levenshtein'
142
+ def __init__ ( self , differ_class : Union [ str , Differ , None ] = None ):
143
+ self . _differ_class = Levenshtein if differ_class is None else differ_class
150
144
151
- def __init__ (self , mode = None , differ_class = None ):
152
- self ._mode = mode
153
-
154
- def compare (self , ref : Schema , hyp : Schema ):
155
- ref_str = '' .join ([i ['item' ] for i in ref ])
156
- hyp_str = '' .join ([i ['item' ] for i in hyp ])
157
- total_ref = len (ref_str )
145
+ def compare (self , ref : type_schema , hyp : type_schema ):
146
+ ref_str = '' .join (traversible (ref ))
147
+ hyp_str = '' .join (traversible (hyp ))
158
148
159
- if total_ref == 0 :
149
+ if len ( ref_str ) == 0 :
160
150
return 0 if len (hyp_str ) == 0 else 1
161
151
162
- return editdistance .eval (ref_str , hyp_str ) / total_ref
152
+ diffs = get_differ (ref_str , hyp_str , differ_class = self ._differ_class )
153
+ return diffs .get_error_rate ()
163
154
164
155
165
156
class DiffCounts (Metric ):
@@ -169,12 +160,10 @@ class DiffCounts(Metric):
169
160
:param differ_class: see :py:mod:`benchmarkstt.Differ.core`
170
161
"""
171
162
172
- def __init__ (self , differ_class : Differ = None ):
173
- if differ_class is None :
174
- differ_class = RatcliffObershelp
163
+ def __init__ (self , differ_class : Union [str , Differ , None ] = None ):
175
164
self ._differ_class = differ_class
176
165
177
- def compare (self , ref : Schema , hyp : Schema ) -> OpcodeCounts :
166
+ def compare (self , ref : type_schema , hyp : type_schema ) -> OpcodeCounts :
178
167
diffs = get_differ (ref , hyp , differ_class = self ._differ_class )
179
168
return diffs .get_opcode_counts ()
180
169
0 commit comments