2
2
import argparse
3
3
import torch
4
4
5
- from onmt .modules .position_ffn import ActivationFunction
6
-
7
-
8
- def get_ctranslate2_model_spec (opt ):
9
- """Creates a CTranslate2 model specification from the model options."""
10
- with_relative_position = getattr (opt , "max_relative_positions" , 0 ) > 0
11
- relu = ActivationFunction .relu
12
- is_ct2_compatible = (
13
- opt .encoder_type == "transformer"
14
- and opt .decoder_type == "transformer"
15
- and not getattr (opt , "aan_useffn" , False )
16
- and getattr (opt , "self_attn_type" , "scaled-dot" ) == "scaled-dot"
17
- and getattr (opt , "pos_ffn_activation_fn" , relu ) == relu
18
- and ((opt .position_encoding and not with_relative_position )
19
- or (with_relative_position and not opt .position_encoding )))
20
- if not is_ct2_compatible :
21
- return None
22
- import ctranslate2
23
- num_heads = getattr (opt , "heads" , 8 )
24
- return ctranslate2 .specs .TransformerSpec (
25
- (opt .enc_layers , opt .dec_layers ),
26
- num_heads ,
27
- with_relative_position = with_relative_position )
28
-
29
5
30
6
def main ():
31
7
parser = argparse .ArgumentParser (
@@ -49,14 +25,13 @@ def main():
49
25
model ["optim" ] = None
50
26
torch .save (model , opt .output )
51
27
elif opt .format == "ctranslate2" :
52
- model_spec = get_ctranslate2_model_spec (model ["opt" ])
53
- if model_spec is None :
54
- raise ValueError ("This model is not supported by CTranslate2. Go "
55
- "to https://github.com/OpenNMT/CTranslate2 for "
56
- "more information on supported models." )
57
28
import ctranslate2
29
+ if not hasattr (ctranslate2 , "__version__" ):
30
+ raise RuntimeError (
31
+ "onmt_release_model script requires ctranslate2 >= 2.0.0"
32
+ )
58
33
converter = ctranslate2 .converters .OpenNMTPyConverter (opt .model )
59
- converter .convert (opt .output , model_spec , force = True ,
34
+ converter .convert (opt .output , force = True ,
60
35
quantization = opt .quantization )
61
36
62
37
0 commit comments