Skip to content

Commit c6f47c6

Browse files
authored
[Backend] integrate x2paddle tool (#1148)
* integrate x2paddle tool * fix caffe convertion bug * add post support * fix tar file path * add download api * add tips for users
1 parent 4956520 commit c6f47c6

File tree

5 files changed

+415
-0
lines changed

5 files changed

+415
-0
lines changed

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,5 @@ matplotlib
1010
pandas
1111
multiprocess
1212
packaging
13+
x2paddle
14+
rarfile
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright (c) 2022 VisualDL Authors. All Rights Reserve.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# =======================================================================
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Copyright (c) 2022 VisualDL Authors. All Rights Reserve.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# =======================================================================
15+
import base64
16+
import json
17+
import os
18+
import tempfile
19+
from collections import deque
20+
from threading import Lock
21+
22+
from flask import request
23+
from x2paddle.convert import caffe2paddle
24+
from x2paddle.convert import onnx2paddle
25+
26+
from .xarfile import archive
27+
from .xarfile import unarchive
28+
from visualdl.server.api import gen_result
29+
from visualdl.server.api import result
30+
31+
32+
class ModelConvertApi(object):
33+
def __init__(self):
34+
self.supported_formats = {'onnx', 'caffe'}
35+
self.lock = Lock()
36+
self.translated_models = deque(
37+
maxlen=5) # used to store user's translated model for download
38+
self.request_id = 0 # used to store user's request
39+
40+
@result()
41+
def convert_model(self, format):
42+
file_handle = request.files['file']
43+
data = file_handle.stream.read()
44+
if format not in self.supported_formats:
45+
raise RuntimeError('Model format {} is not supported. \
46+
Only onnx and caffe models are supported now.'.format(format))
47+
result = {}
48+
result['from'] = format
49+
result['to'] = 'paddle'
50+
# call x2paddle to convert models
51+
with tempfile.TemporaryDirectory(
52+
suffix='x2paddle_translated_models') as tmpdirname:
53+
with tempfile.NamedTemporaryFile() as fp:
54+
fp.write(data)
55+
fp.flush()
56+
try:
57+
if format == 'onnx':
58+
try:
59+
import onnx # noqa: F401
60+
except Exception:
61+
raise RuntimeError(
62+
"[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\"."
63+
)
64+
onnx2paddle(fp.name, tmpdirname)
65+
elif format == 'caffe':
66+
with tempfile.TemporaryDirectory() as unarchivedir:
67+
unarchive(fp.name, unarchivedir)
68+
prototxt_path = None
69+
weight_path = None
70+
for dirname, subdirs, filenames in os.walk(
71+
unarchivedir):
72+
for filename in filenames:
73+
if '.prototxt' in filename:
74+
prototxt_path = os.path.join(
75+
dirname, filename)
76+
if '.caffemodel' in filename:
77+
weight_path = os.path.join(
78+
dirname, filename)
79+
if prototxt_path is None or weight_path is None:
80+
raise RuntimeError(
81+
".prototxt or .caffemodel file is missing in your archive file, \
82+
please check files uploaded.")
83+
caffe2paddle(prototxt_path, weight_path,
84+
tmpdirname, None)
85+
except Exception as e:
86+
raise RuntimeError(
87+
"[Convertion error] {}.\n Please open an issue at \
88+
https://github.com/PaddlePaddle/X2Paddle/issues to report your problem."
89+
.format(e))
90+
with self.lock:
91+
origin_dir = os.getcwd()
92+
os.chdir(os.path.dirname(tmpdirname))
93+
archive_path = os.path.join(
94+
os.path.dirname(tmpdirname),
95+
archive(os.path.basename(tmpdirname)))
96+
os.chdir(origin_dir)
97+
result['request_id'] = self.request_id
98+
self.request_id += 1
99+
with open(archive_path, 'rb') as archive_fp:
100+
self.translated_models.append((result['request_id'],
101+
archive_fp.read()))
102+
with open(
103+
os.path.join(tmpdirname, 'inference_model',
104+
'model.pdmodel'), 'rb') as model_fp:
105+
model_encoded = base64.b64encode(
106+
model_fp.read()).decode('utf-8')
107+
result['pdmodel'] = model_encoded
108+
if os.path.exists(archive_path):
109+
os.remove(archive_path)
110+
111+
return result
112+
113+
@result('application/octet-stream')
114+
def download_model(self, request_id):
115+
for stored_request_id, data in self.translated_models:
116+
if str(stored_request_id) == request_id:
117+
return data
118+
119+
120+
def create_model_convert_api_call():
121+
api = ModelConvertApi()
122+
routes = {
123+
'convert': (api.convert_model, ['format']),
124+
'download': (api.download_model, ['request_id'])
125+
}
126+
127+
def call(path: str, args):
128+
route = routes.get(path)
129+
if not route:
130+
return json.dumps(gen_result(
131+
status=1, msg='api not found')), 'application/json', None
132+
method, call_arg_names = route
133+
call_args = [args.get(name) for name in call_arg_names]
134+
return method(*call_args)
135+
136+
return call

0 commit comments

Comments
 (0)