Skip to content

Commit 54c777a

Browse files
authored
Update CTranslate2 usage for 2.0 (#2071)
1 parent d5d3c74 commit 54c777a

File tree

2 files changed

+7
-31
lines changed

2 files changed

+7
-31
lines changed

onmt/bin/release_model.py

Lines changed: 5 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,6 @@
22
import argparse
33
import torch
44

5-
from onmt.modules.position_ffn import ActivationFunction
6-
7-
8-
def get_ctranslate2_model_spec(opt):
9-
"""Creates a CTranslate2 model specification from the model options."""
10-
with_relative_position = getattr(opt, "max_relative_positions", 0) > 0
11-
relu = ActivationFunction.relu
12-
is_ct2_compatible = (
13-
opt.encoder_type == "transformer"
14-
and opt.decoder_type == "transformer"
15-
and not getattr(opt, "aan_useffn", False)
16-
and getattr(opt, "self_attn_type", "scaled-dot") == "scaled-dot"
17-
and getattr(opt, "pos_ffn_activation_fn", relu) == relu
18-
and ((opt.position_encoding and not with_relative_position)
19-
or (with_relative_position and not opt.position_encoding)))
20-
if not is_ct2_compatible:
21-
return None
22-
import ctranslate2
23-
num_heads = getattr(opt, "heads", 8)
24-
return ctranslate2.specs.TransformerSpec(
25-
(opt.enc_layers, opt.dec_layers),
26-
num_heads,
27-
with_relative_position=with_relative_position)
28-
295

306
def main():
317
parser = argparse.ArgumentParser(
@@ -49,14 +25,13 @@ def main():
4925
model["optim"] = None
5026
torch.save(model, opt.output)
5127
elif opt.format == "ctranslate2":
52-
model_spec = get_ctranslate2_model_spec(model["opt"])
53-
if model_spec is None:
54-
raise ValueError("This model is not supported by CTranslate2. Go "
55-
"to https://github.com/OpenNMT/CTranslate2 for "
56-
"more information on supported models.")
5728
import ctranslate2
29+
if not hasattr(ctranslate2, "__version__"):
30+
raise RuntimeError(
31+
"onmt_release_model script requires ctranslate2 >= 2.0.0"
32+
)
5833
converter = ctranslate2.converters.OpenNMTPyConverter(opt.model)
59-
converter.convert(opt.output, model_spec, force=True,
34+
converter.convert(opt.output, force=True,
6035
quantization=opt.quantization)
6136

6237

onmt/translate/translation_server.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ def translate(self, texts_to_translate, batch_size=8, tgt=None):
111111
target_prefix=tgt if self.target_prefix else None,
112112
max_batch_size=self.batch_size,
113113
beam_size=self.beam_size,
114-
num_hypotheses=self.n_best
114+
num_hypotheses=self.n_best,
115+
return_scores=True,
115116
)
116117
scores = [[item["score"] for item in ex] for ex in preds]
117118
predictions = [[" ".join(item["tokens"]) for item in ex]

0 commit comments

Comments
 (0)