|
1 | 1 | import argparse
|
2 |
| -import json |
3 | 2 |
|
4 |
| -import torch |
5 | 3 | from torch.autograd import Variable
|
| 4 | +from tqdm import tqdm |
| 5 | + |
| 6 | +from decoder import GreedyDecoder, BeamCTCDecoder |
6 | 7 |
|
7 | 8 | from data.data_loader import SpectrogramDataset, AudioDataLoader
|
8 |
| -from decoder import GreedyDecoder, BeamCTCDecoder, Scorer, KenLMScorer |
9 | 9 | from model import DeepSpeech
|
10 | 10 |
|
11 | 11 | parser = argparse.ArgumentParser(description='DeepSpeech prediction')
|
|
19 | 19 | parser.add_argument('--decoder', default="greedy", choices=["greedy", "beam"], type=str, help="Decoder to use")
|
20 | 20 | beam_args = parser.add_argument_group("Beam Decode Options", "Configurations options for the CTC Beam Search decoder")
|
21 | 21 | beam_args.add_argument('--beam_width', default=10, type=int, help='Beam width to use')
|
22 |
| -beam_args.add_argument('--lm_path', default=None, type=str, help='Path to an (optional) kenlm language model for use with beam search (req\'d with trie)') |
23 |
| -beam_args.add_argument('--trie_path', default=None, type=str, help='Path to an (optional) trie dictionary for use with beam search (req\'d with LM)') |
| 22 | +beam_args.add_argument('--lm_path', default=None, type=str, |
| 23 | + help='Path to an (optional) kenlm language model for use with beam search (req\'d with trie)') |
| 24 | +beam_args.add_argument('--trie_path', default=None, type=str, |
| 25 | + help='Path to an (optional) trie dictionary for use with beam search (req\'d with LM)') |
24 | 26 | beam_args.add_argument('--lm_alpha', default=0.8, type=float, help='Language model weight')
|
25 | 27 | beam_args.add_argument('--lm_beta1', default=1, type=float, help='Language model word bonus (all words)')
|
26 | 28 | beam_args.add_argument('--lm_beta2', default=1, type=float, help='Language model word bonus (IV words)')
|
|
34 | 36 | audio_conf = DeepSpeech.get_audio_conf(model)
|
35 | 37 |
|
36 | 38 | if args.decoder == "beam":
|
37 |
| - scorer = None |
38 |
| - if args.lm_path is not None: |
39 |
| - scorer = KenLMScorer(labels, args.lm_path, args.trie_path) |
40 |
| - scorer.set_lm_weight(args.lm_alpha) |
41 |
| - scorer.set_word_weight(args.lm_beta1) |
42 |
| - scorer.set_valid_word_weight(args.lm_beta2) |
43 |
| - else: |
44 |
| - scorer = Scorer() |
45 |
| - decoder = BeamCTCDecoder(labels, scorer, beam_width=args.beam_width, top_paths=1, space_index=labels.index(' '), blank_index=labels.index('_')) |
| 39 | + decoder = BeamCTCDecoder(labels, beam_width=args.beam_width, top_paths=1, space_index=labels.index(' '), |
| 40 | + blank_index=labels.index('_'), lm_path=args.lm_path, |
| 41 | + trie_path=args.trie_path, lm_alpha=args.lm_alpha, lm_beta1=args.lm_beta1, |
| 42 | + lm_beta2=args.lm_beta2) |
46 | 43 | else:
|
47 | 44 | decoder = GreedyDecoder(labels, space_index=labels.index(' '), blank_index=labels.index('_'))
|
48 | 45 |
|
|
51 | 48 | test_loader = AudioDataLoader(test_dataset, batch_size=args.batch_size,
|
52 | 49 | num_workers=args.num_workers)
|
53 | 50 | total_cer, total_wer = 0, 0
|
54 |
| - for i, (data) in enumerate(test_loader): |
| 51 | + for i, (data) in tqdm(enumerate(test_loader), total=len(test_loader)): |
55 | 52 | inputs, targets, input_percentages, target_sizes = data
|
56 | 53 |
|
57 | 54 | inputs = Variable(inputs, volatile=True)
|
|
0 commit comments