Skip to content

Commit 20b9570

Browse files
committed
get_error_rate for WER
1 parent 950c081 commit 20b9570

File tree

7 files changed

+62
-47
lines changed

7 files changed

+62
-47
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def filter_requirements(line):
6262
'jsonrpcserver>=4.0.1',
6363
'gunicorn>=19.9.0',
6464
'docutils>=0.14',
65+
'edit_distance>=1.0.4',
6566
'editdistance>=0.5.3',
6667
'Unidecode>=1.1.2',
6768
],

src/benchmarkstt/diff/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,4 +64,22 @@ def get_opcodes(self):
6464
def get_opcode_counts(self):
6565
raise NotImplementedError()
6666

67+
@abstractmethod
68+
def get_error_rate(self):
69+
raise NotImplementedError()
70+
71+
72+
class Differ(DifferInterface, metaclass=ABCMeta):
73+
def get_opcode_counts(self):
74+
return get_opcode_counts(self.get_opcodes())
75+
76+
def get_error_rate(self):
77+
counts = self.get_opcode_counts()
78+
79+
changes = counts.replace + counts.delete + counts.insert
80+
total = counts.equal + counts.replace + counts.delete
81+
82+
return changes / total
83+
84+
6785
factory = CoreFactory(DifferInterface, False)

src/benchmarkstt/diff/core.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from difflib import SequenceMatcher
66
from benchmarkstt.diff import Differ
77
import edit_distance
8+
import editdistance
89

910

1011
class RatcliffObershelp(Differ):
@@ -55,6 +56,11 @@ def __init__(self, a, b, **kwargs):
5556
def get_opcodes(self):
5657
return self.simplify_opcodes(self._matcher.get_opcodes())
5758

59+
def get_error_rate(self):
60+
a = self._kwargs['a']
61+
b = self._kwargs['b']
62+
return editdistance.eval(a, b) / len(a)
63+
5864
@staticmethod
5965
def simplify_opcodes(opcodes):
6066
new_codes = []

src/benchmarkstt/metrics/core.py

Lines changed: 20 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from benchmarkstt.schema import Schema
1+
from benchmarkstt.schema import Schema, Item
22
import logging
3-
from benchmarkstt.diff import Differ
3+
from benchmarkstt.diff import Differ, factory as differ_factory
44
from benchmarkstt.diff.core import RatcliffObershelp
55
from benchmarkstt.diff.formatter import format_diff
66
from benchmarkstt.metrics import Metric
@@ -12,11 +12,13 @@
1212
OpcodeCounts = namedtuple('OpcodeCounts',
1313
('equal', 'replace', 'insert', 'delete'))
1414

15+
type_schema = Union[Schema, list]
16+
1517

1618
def traversible(schema, key=None):
1719
if key is None:
1820
key = 'item'
19-
return [word[key] for word in schema]
21+
return [item if type(item) is str else item[key] for item in schema]
2022

2123

2224
def get_differ(a, b, differ_class: Differ):
@@ -41,7 +43,7 @@ def __init__(self, differ_class: Differ = None, dialect: str = None):
4143
self._differ_class = differ_class
4244
self._dialect = dialect
4345

44-
def compare(self, ref: Schema, hyp: Schema):
46+
def compare(self, ref: type_schema, hyp: type_schema):
4547
differ = get_differ(ref, hyp, differ_class=self._differ_class)
4648
a = traversible(ref)
4749
b = traversible(hyp)
@@ -82,24 +84,17 @@ class WER(Metric):
8284
INS_PENALTY = 1
8385
SUB_PENALTY = 1
8486

85-
def __init__(self, mode=None, differ_class: Differ = None):
87+
def __init__(self, mode=None, differ_class: Union[str, Differ, None] = None):
8688
self._mode = mode
8789

8890
if differ_class is None:
8991
differ_class = RatcliffObershelp
9092
self._differ_class = differ_class
93+
9194
if mode == self.MODE_HUNT:
9295
self.DEL_PENALTY = self.INS_PENALTY = .5
9396

94-
def compare(self, ref: Schema, hyp: Schema) -> float:
95-
if self._mode == self.MODE_LEVENSHTEIN:
96-
ref_list = [i['item'] for i in ref]
97-
hyp_list = [i['item'] for i in hyp]
98-
total_ref = len(ref_list)
99-
if total_ref == 0:
100-
return 1
101-
return editdistance.eval(ref_list, hyp_list) / total_ref
102-
97+
def compare(self, ref: type_schema, hyp: type_schema) -> float:
10398
diffs = get_differ(ref, hyp, differ_class=self._differ_class)
10499

105100
counts = diffs.get_opcode_counts()
@@ -141,25 +136,21 @@ class CER(Metric):
141136
will first be split into words, ['aa','bb','cc'], and
142137
then merged into a final string for evaluation: 'aabbcc'.
143138
144-
:param mode: 'levenshtein' (default).
145-
:param differ_class: For future use.
139+
:param differ_class: see :py:mod:`benchmarkstt.diff.core`
146140
"""
147141

148-
# CER modes
149-
MODE_LEVENSHTEIN = 'levenshtein'
142+
def __init__(self, differ_class: Union[str, Differ, None] = None):
143+
self._differ_class = Levenshtein if differ_class is None else differ_class
150144

151-
def __init__(self, mode=None, differ_class=None):
152-
self._mode = mode
153-
154-
def compare(self, ref: Schema, hyp: Schema):
155-
ref_str = ''.join([i['item'] for i in ref])
156-
hyp_str = ''.join([i['item'] for i in hyp])
157-
total_ref = len(ref_str)
145+
def compare(self, ref: type_schema, hyp: type_schema):
146+
ref_str = ''.join(traversible(ref))
147+
hyp_str = ''.join(traversible(hyp))
158148

159-
if total_ref == 0:
149+
if len(ref_str) == 0:
160150
return 0 if len(hyp_str) == 0 else 1
161151

162-
return editdistance.eval(ref_str, hyp_str) / total_ref
152+
diffs = get_differ(ref_str, hyp_str, differ_class=self._differ_class)
153+
return diffs.get_error_rate()
163154

164155

165156
class DiffCounts(Metric):
@@ -169,12 +160,10 @@ class DiffCounts(Metric):
169160
:param differ_class: see :py:mod:`benchmarkstt.diff.core`
170161
"""
171162

172-
def __init__(self, differ_class: Differ = None):
173-
if differ_class is None:
174-
differ_class = RatcliffObershelp
163+
def __init__(self, differ_class: Union[str, Differ, None] = None):
175164
self._differ_class = differ_class
176165

177-
def compare(self, ref: Schema, hyp: Schema) -> OpcodeCounts:
166+
def compare(self, ref: type_schema, hyp: type_schema) -> OpcodeCounts:
178167
diffs = get_differ(ref, hyp, differ_class=self._differ_class)
179168
return diffs.get_opcode_counts()
180169

src/benchmarkstt/schema.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
"""
44
import json
55
from collections.abc import Mapping
6-
from typing import Union
76
from collections import defaultdict
87

98

@@ -51,6 +50,9 @@ def __iter__(self):
5150
def __repr__(self):
5251
return 'Item(%s)' % (self.json(),)
5352

53+
def __hash__(self):
54+
return hash(self._val['item'])
55+
5456
def json(self, **kwargs):
5557
return Schema.dumps(self, **kwargs)
5658

tests/benchmarkstt/test_diff.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ def clean_opcodes(opcodes):
2525
def test_simple_levenshtein_ratcliff_similarity():
2626
a = list('012345')
2727
b = list('023x45')
28-
assert clean_opcodes(Levenshtein(a, b).get_opcodes()) == \
29-
clean_opcodes(RatcliffObershelp(a, b).get_opcodes())
28+
assert(clean_opcodes(Levenshtein(a, b).get_opcodes()) ==
29+
clean_opcodes(RatcliffObershelp(a, b).get_opcodes()))
3030

3131

3232
@differs_decorator
@@ -35,42 +35,43 @@ def test_simple(differ):
3535
'0123456HIJkopq',
3636
'0123456HIJKlmnopq'
3737
)
38-
assert clean_opcodes(sm.get_opcodes()) == \
38+
assert(clean_opcodes(sm.get_opcodes()) ==
3939
clean_opcodes([('equal', 0, 10, 0, 10),
4040
('replace', 10, 11, 10, 14),
41-
('equal', 11, 14, 14, 17)])
41+
('equal', 11, 14, 14, 17)]))
4242

4343

4444
@differs_decorator
4545
def test_one_insert(differ):
4646
sm = differ('b' * 100, 'a' + 'b' * 100)
47-
assert clean_opcodes(sm.get_opcodes()) == \
47+
assert(clean_opcodes(sm.get_opcodes()) ==
4848
clean_opcodes([('insert', 0, 0, 0, 1),
49-
('equal', 0, 100, 1, 101)])
49+
('equal', 0, 100, 1, 101)]))
50+
5051
sm = differ('b' * 100, 'b' * 50 + 'a' + 'b' * 50)
51-
assert clean_opcodes(sm.get_opcodes()) == \
52+
assert(clean_opcodes(sm.get_opcodes()) ==
5253
clean_opcodes([('equal', 0, 50, 0, 50),
5354
('insert', 50, 50, 50, 51),
54-
('equal', 50, 100, 51, 101)])
55+
('equal', 50, 100, 51, 101)]))
5556

5657

5758
@differs_decorator
5859
def test_one_delete(differ):
5960
sm = differ('a' * 40 + 'c' + 'b' * 40, 'a' * 40 + 'b' * 40)
60-
assert clean_opcodes(sm.get_opcodes()) == \
61+
assert(clean_opcodes(sm.get_opcodes()) ==
6162
clean_opcodes([('equal', 0, 40, 0, 40),
6263
('delete', 40, 41, 40, 40),
63-
('equal', 41, 81, 40, 80)])
64+
('equal', 41, 81, 40, 80)]))
6465

6566

6667
def test_ratcliffobershelp():
6768
ref = "a b c d e f"
6869
hyp = "a b d e kfmod fgdjn idf giudfg diuf dufg idgiudgd"
6970
sm = RatcliffObershelp(ref, hyp)
70-
assert clean_opcodes(sm.get_opcodes()) == \
71+
assert(clean_opcodes(sm.get_opcodes()) ==
7172
clean_opcodes([('equal', 0, 3, 0, 3),
7273
('delete', 3, 5, 3, 3),
7374
('equal', 5, 10, 3, 8),
7475
('insert', 10, 10, 8, 9),
7576
('equal', 10, 11, 9, 10),
76-
('insert', 11, 11, 10, 49)])
77+
('insert', 11, 11, 10, 49)]))

tests/benchmarkstt/test_metrics_core.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ def test_wer(a, b, exp):
4141
assert WER(differ_class='levenshtein').compare(PlainText(a), PlainText(b)) == wer_levenshtein
4242

4343

44-
4544
@pytest.mark.parametrize('a,b,entities_list,weights,exp_beer,exp_occ', [
4645
['madam is here', 'adam is here', ['madam', 'here'], [100, 10], (1.0, 0.0), (1, 1)],
4746
['theresa may is here', 'theresa may is there', ['theresa may', 'here'], [10, 100], (0.0, 1.0), (1, 1)],
@@ -113,5 +112,4 @@ def test_wa_beer(a, b, entities_list, weights, exp):
113112
def test_cer(a, b, exp):
114113
cer_levenshtein, = exp
115114

116-
assert CER(mode=CER.MODE_LEVENSHTEIN).compare(PlainText(a), PlainText(b)) == cer_levenshtein
117-
115+
assert CER(differ_class='levenshtein').compare(PlainText(a), PlainText(b)) == cer_levenshtein

0 commit comments

Comments
 (0)