Skip to content

Commit fc8efd8

Browse files
author
Sean Naren
authored
Merge pull request #141 from SeanNaren/fix-inference
Refactor of testing/prediction, added progress
2 parents 3f10ec0 + 683eabe commit fc8efd8

File tree

4 files changed

+38
-41
lines changed

4 files changed

+38
-41
lines changed

decoder.py

Lines changed: 13 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -17,15 +17,8 @@
1717

1818
import Levenshtein as Lev
1919
import torch
20-
from enum import Enum
2120
from six.moves import xrange
2221

23-
try:
24-
from pytorch_ctc import CTCBeamDecoder as CTCBD
25-
from pytorch_ctc import Scorer, KenLMScorer
26-
except ImportError:
27-
print("warn: pytorch_ctc unavailable. Only greedy decoding is supported.")
28-
2922

3023
class Decoder(object):
3124
"""
@@ -134,17 +127,25 @@ def decode(self, probs, sizes=None):
134127

135128

136129
class BeamCTCDecoder(Decoder):
137-
def __init__(self, labels, scorer, beam_width=20, top_paths=1, blank_index=0, space_index=28):
130+
def __init__(self, labels, beam_width=20, top_paths=1, blank_index=0, space_index=28, lm_path=None, trie_path=None,
131+
lm_alpha=None, lm_beta1=None, lm_beta2=None):
138132
super(BeamCTCDecoder, self).__init__(labels, blank_index=blank_index, space_index=space_index)
139133
self._beam_width = beam_width
140134
self._top_n = top_paths
135+
141136
try:
142-
import pytorch_ctc
137+
from pytorch_ctc import CTCBeamDecoder, Scorer, KenLMScorer
143138
except ImportError:
144139
raise ImportError("BeamCTCDecoder requires pytorch_ctc package.")
145-
146-
self._decoder = CTCBD(scorer, labels, top_paths=top_paths, beam_width=beam_width,
147-
blank_index=blank_index, space_index=space_index, merge_repeated=False)
140+
if lm_path is not None:
141+
scorer = KenLMScorer(labels, lm_path, trie_path)
142+
scorer.set_lm_weight(lm_alpha)
143+
scorer.set_word_weight(lm_beta1)
144+
scorer.set_valid_word_weight(lm_beta2)
145+
else:
146+
scorer = Scorer()
147+
self._decoder = CTCBeamDecoder(scorer, labels, top_paths=top_paths, beam_width=beam_width,
148+
blank_index=blank_index, space_index=space_index, merge_repeated=False)
148149

149150
def decode(self, probs, sizes=None):
150151
sizes = sizes.cpu() if sizes is not None else None

predict.py

Lines changed: 12 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -2,11 +2,11 @@
22
import sys
33
import time
44

5-
import torch
5+
from decoder import GreedyDecoder, BeamCTCDecoder
6+
67
from torch.autograd import Variable
78

89
from data.data_loader import SpectrogramParser
9-
from decoder import GreedyDecoder, BeamCTCDecoder, Scorer, KenLMScorer
1010
from model import DeepSpeech
1111

1212
parser = argparse.ArgumentParser(description='DeepSpeech prediction')
@@ -18,8 +18,10 @@
1818
parser.add_argument('--decoder', default="greedy", choices=["greedy", "beam"], type=str, help="Decoder to use")
1919
beam_args = parser.add_argument_group("Beam Decode Options", "Configurations options for the CTC Beam Search decoder")
2020
beam_args.add_argument('--beam_width', default=10, type=int, help='Beam width to use')
21-
beam_args.add_argument('--lm_path', default=None, type=str, help='Path to an (optional) kenlm language model for use with beam search (req\'d with trie)')
22-
beam_args.add_argument('--trie_path', default=None, type=str, help='Path to an (optional) trie dictionary for use with beam search (req\'d with LM)')
21+
beam_args.add_argument('--lm_path', default=None, type=str,
22+
help='Path to an (optional) kenlm language model for use with beam search (req\'d with trie)')
23+
beam_args.add_argument('--trie_path', default=None, type=str,
24+
help='Path to an (optional) trie dictionary for use with beam search (req\'d with LM)')
2325
beam_args.add_argument('--lm_alpha', default=0.8, type=float, help='Language model weight')
2426
beam_args.add_argument('--lm_beta1', default=1, type=float, help='Language model word bonus (all words)')
2527
beam_args.add_argument('--lm_beta2', default=1, type=float, help='Language model word bonus (IV words)')
@@ -33,15 +35,10 @@
3335
audio_conf = DeepSpeech.get_audio_conf(model)
3436

3537
if args.decoder == "beam":
36-
scorer = None
37-
if args.lm_path is not None:
38-
scorer = KenLMScorer(labels, args.lm_path, args.trie_path)
39-
scorer.set_lm_weight(args.lm_alpha)
40-
scorer.set_word_weight(args.lm_beta1)
41-
scorer.set_valid_word_weight(args.lm_beta2)
42-
else:
43-
scorer = Scorer()
44-
decoder = BeamCTCDecoder(labels, scorer, beam_width=args.beam_width, top_paths=1, space_index=labels.index(' '), blank_index=labels.index('_'))
38+
decoder = BeamCTCDecoder(labels, beam_width=args.beam_width, top_paths=1, space_index=labels.index(' '),
39+
blank_index=labels.index('_'), lm_path=args.lm_path,
40+
trie_path=args.trie_path, lm_alpha=args.lm_alpha, lm_beta1=args.lm_beta1,
41+
lm_beta2=args.lm_beta2)
4542
else:
4643
decoder = GreedyDecoder(labels, space_index=labels.index(' '), blank_index=labels.index('_'))
4744

@@ -56,4 +53,5 @@
5653
t1 = time.time()
5754

5855
print(decoded_output[0])
59-
print("Decoded {0:.2f} seconds of audio in {1:.2f} seconds".format(spect.size(3)*audio_conf['window_stride'], t1-t0), file=sys.stderr)
56+
print("Decoded {0:.2f} seconds of audio in {1:.2f} seconds".format(spect.size(3) * audio_conf['window_stride'],
57+
t1 - t0), file=sys.stderr)

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,3 +3,4 @@ torch
33
visdom
44
wget
55
librosa
6+
tqdm

test.py

Lines changed: 12 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,11 @@
11
import argparse
2-
import json
32

4-
import torch
53
from torch.autograd import Variable
4+
from tqdm import tqdm
5+
6+
from decoder import GreedyDecoder, BeamCTCDecoder
67

78
from data.data_loader import SpectrogramDataset, AudioDataLoader
8-
from decoder import GreedyDecoder, BeamCTCDecoder, Scorer, KenLMScorer
99
from model import DeepSpeech
1010

1111
parser = argparse.ArgumentParser(description='DeepSpeech prediction')
@@ -19,8 +19,10 @@
1919
parser.add_argument('--decoder', default="greedy", choices=["greedy", "beam"], type=str, help="Decoder to use")
2020
beam_args = parser.add_argument_group("Beam Decode Options", "Configurations options for the CTC Beam Search decoder")
2121
beam_args.add_argument('--beam_width', default=10, type=int, help='Beam width to use')
22-
beam_args.add_argument('--lm_path', default=None, type=str, help='Path to an (optional) kenlm language model for use with beam search (req\'d with trie)')
23-
beam_args.add_argument('--trie_path', default=None, type=str, help='Path to an (optional) trie dictionary for use with beam search (req\'d with LM)')
22+
beam_args.add_argument('--lm_path', default=None, type=str,
23+
help='Path to an (optional) kenlm language model for use with beam search (req\'d with trie)')
24+
beam_args.add_argument('--trie_path', default=None, type=str,
25+
help='Path to an (optional) trie dictionary for use with beam search (req\'d with LM)')
2426
beam_args.add_argument('--lm_alpha', default=0.8, type=float, help='Language model weight')
2527
beam_args.add_argument('--lm_beta1', default=1, type=float, help='Language model word bonus (all words)')
2628
beam_args.add_argument('--lm_beta2', default=1, type=float, help='Language model word bonus (IV words)')
@@ -34,15 +36,10 @@
3436
audio_conf = DeepSpeech.get_audio_conf(model)
3537

3638
if args.decoder == "beam":
37-
scorer = None
38-
if args.lm_path is not None:
39-
scorer = KenLMScorer(labels, args.lm_path, args.trie_path)
40-
scorer.set_lm_weight(args.lm_alpha)
41-
scorer.set_word_weight(args.lm_beta1)
42-
scorer.set_valid_word_weight(args.lm_beta2)
43-
else:
44-
scorer = Scorer()
45-
decoder = BeamCTCDecoder(labels, scorer, beam_width=args.beam_width, top_paths=1, space_index=labels.index(' '), blank_index=labels.index('_'))
39+
decoder = BeamCTCDecoder(labels, beam_width=args.beam_width, top_paths=1, space_index=labels.index(' '),
40+
blank_index=labels.index('_'), lm_path=args.lm_path,
41+
trie_path=args.trie_path, lm_alpha=args.lm_alpha, lm_beta1=args.lm_beta1,
42+
lm_beta2=args.lm_beta2)
4643
else:
4744
decoder = GreedyDecoder(labels, space_index=labels.index(' '), blank_index=labels.index('_'))
4845

@@ -51,7 +48,7 @@
5148
test_loader = AudioDataLoader(test_dataset, batch_size=args.batch_size,
5249
num_workers=args.num_workers)
5350
total_cer, total_wer = 0, 0
54-
for i, (data) in enumerate(test_loader):
51+
for i, (data) in tqdm(enumerate(test_loader), total=len(test_loader)):
5552
inputs, targets, input_percentages, target_sizes = data
5653

5754
inputs = Variable(inputs, volatile=True)

0 commit comments

Comments
 (0)