15
15
class PythonCodeGen (CodeGenBase ):
16
16
"""Generation of Python code for a given Schema Salad definition."""
17
17
def __init__ (self , out ):
18
- # type: (IO[str ]) -> None
18
+ # type: (IO[Any ]) -> None
19
19
super (PythonCodeGen , self ).__init__ ()
20
20
self .out = out
21
21
self .current_class_is_abstract = False
@@ -65,7 +65,7 @@ def begin_class(self, # pylint: disable=too-many-arguments
65
65
else :
66
66
ext = "Savable"
67
67
68
- self .out .write ("class %s(%s):\n " % (self . safe_name ( classname ) , ext ))
68
+ self .out .write ("class %s(%s):\n " % (classname , ext ))
69
69
70
70
if doc :
71
71
self .out .write (' """\n ' )
@@ -79,30 +79,55 @@ def begin_class(self, # pylint: disable=too-many-arguments
79
79
self .out .write (" pass\n \n " )
80
80
return
81
81
82
+ safe_inits = ["self" ] # type: List[Text]
83
+ safe_inits .extend ([self .safe_name (f ) for f in field_names if f != "class" ])
84
+ inits_types = ", " .join (["Any" ]* (len (safe_inits ) - 1 ))
82
85
self .out .write (
83
- """ def __init__(self, _doc, baseuri, loadingOptions, docRoot=None):
84
- # type: (Any, Text, LoadingOptions, Optional[Text]) -> None
85
- doc = copy.copy(_doc)
86
- if hasattr(_doc, 'lc'):
87
- doc.lc.data = _doc.lc.data
88
- doc.lc.filename = _doc.lc.filename
89
- errors = []
90
- self.loadingOptions = loadingOptions
86
+ " def __init__(" + ", " .join (safe_inits ) + ", extension_fields=None, loadingOptions=None):\n "
87
+ " # type: (" + inits_types + """, Optional[Dict[Text, Any]], Optional[LoadingOptions]) -> None
88
+ if extension_fields:
89
+ self.extension_fields = extension_fields
90
+ else:
91
+ self.extension_fields = yaml.comments.CommentedMap()
92
+ if loadingOptions:
93
+ self.loadingOptions = loadingOptions
94
+ else:
95
+ self.loadingOptions = LoadingOptions()
91
96
""" )
97
+ field_inits = ""
98
+ for name in field_names :
99
+ if name == "class" :
100
+ field_inits += """ self.class_ = "{}"
101
+ """ .format (classname )
102
+ else :
103
+ field_inits += """ self.{0} = {0}
104
+ """ .format (self .safe_name (name ))
105
+ self .out .write (field_inits + '\n '
106
+ + """
107
+ @classmethod
108
+ def fromDoc(cls, doc, baseuri, loadingOptions, docRoot=None):
109
+ # type: (Any, Text, LoadingOptions, Optional[Text]) -> {}
110
+
111
+ _doc = copy.copy(doc)
112
+ if hasattr(doc, 'lc'):
113
+ _doc.lc.data = doc.lc.data
114
+ _doc.lc.filename = doc.lc.filename
115
+ errors = []
116
+ """ .format (classname ))
92
117
93
118
self .idfield = idfield
94
119
95
120
self .serializer .write ("""
96
121
def save(self, top=False, base_url="", relative_uris=True):
97
122
# type: (bool, Text, bool) -> Dict[Text, Any]
98
- r = {} # type: Dict[Text, Any]
123
+ r = yaml.comments.CommentedMap() # type: Dict[Text, Any]
99
124
for ef in self.extension_fields:
100
125
r[prefix_url(ef, self.loadingOptions.vocab)] = self.extension_fields[ef]
101
126
""" )
102
127
103
128
if "class" in field_names :
104
129
self .out .write ("""
105
- if doc .get('class') != '{class_}':
130
+ if _doc .get('class') != '{class_}':
106
131
raise ValidationException("Not a {class_}")
107
132
108
133
""" .format (class_ = classname ))
@@ -119,14 +144,14 @@ def end_class(self, classname, field_names):
119
144
return
120
145
121
146
self .out .write ("""
122
- self. extension_fields = {{}} # type: Dict[Text, Text]
123
- for k in doc .keys():
124
- if k not in self .attrs:
147
+ extension_fields = yaml.comments.CommentedMap()
148
+ for k in _doc .keys():
149
+ if k not in cls .attrs:
125
150
if ":" in k:
126
151
ex = expand_url(k, u"", loadingOptions, scoped_id=False, vocab_term=False)
127
- self. extension_fields[ex] = doc [k]
152
+ extension_fields[ex] = _doc [k]
128
153
else:
129
- errors.append(SourceLine(doc , k, str).makeError("invalid field `%s`, expected one of: {attrstr}" % (k)))
154
+ errors.append(SourceLine(_doc , k, str).makeError("invalid field `%s`, expected one of: {attrstr}" % (k)))
130
155
break
131
156
132
157
if errors:
@@ -145,7 +170,17 @@ def end_class(self, classname, field_names):
145
170
146
171
self .serializer .write (" attrs = frozenset({attrs})\n " .format (attrs = field_names ))
147
172
173
+ safe_inits = [ self .safe_name (f ) for f in field_names if f != "class" ] # type: List[Text]
174
+
175
+ safe_inits .extend (["extension_fields=extension_fields" , "loadingOptions=loadingOptions" ])
176
+
177
+ self .out .write (""" loadingOptions = copy.deepcopy(loadingOptions)
178
+ loadingOptions.original_doc = _doc
179
+ """ )
180
+ self .out .write (" return cls(" + ", " .join (safe_inits )+ ")\n " )
181
+
148
182
self .out .write (self .serializer .getvalue ())
183
+
149
184
self .out .write ("\n \n " )
150
185
151
186
prims = {
@@ -206,19 +241,19 @@ def declare_id_field(self, name, fieldtype, doc, optional):
206
241
self .declare_field (name , fieldtype , doc , True )
207
242
208
243
if optional :
209
- opt = """self. {safename} = "_:" + str(uuid.uuid4())""" .format (
244
+ opt = """{safename} = "_:" + str(uuid.uuid4())""" .format (
210
245
safename = self .safe_name (name ))
211
246
else :
212
247
opt = """raise ValidationException("Missing {fieldname}")""" .format (
213
248
fieldname = shortname (name ))
214
249
215
250
self .out .write ("""
216
- if self. {safename} is None:
251
+ if {safename} is None:
217
252
if docRoot is not None:
218
- self. {safename} = docRoot
253
+ {safename} = docRoot
219
254
else:
220
255
{opt}
221
- baseuri = self. {safename}
256
+ baseuri = {safename}
222
257
""" .
223
258
format (safename = self .safe_name (name ),
224
259
fieldname = shortname (name ),
@@ -235,30 +270,30 @@ def declare_field(self, name, fieldtype, doc, optional):
235
270
return
236
271
237
272
if optional :
238
- self .out .write (" if '{fieldname}' in doc :\n " .format (fieldname = shortname (name )))
273
+ self .out .write (" if '{fieldname}' in _doc :\n " .format (fieldname = shortname (name )))
239
274
spc = " "
240
275
else :
241
276
spc = ""
242
277
self .out .write ("""{spc} try:
243
- {spc} self. {safename} = load_field(doc .get('{fieldname}'), {fieldtype}, baseuri, loadingOptions)
278
+ {spc} {safename} = load_field(_doc .get('{fieldname}'), {fieldtype}, baseuri, loadingOptions)
244
279
{spc} except ValidationException as e:
245
- {spc} errors.append(SourceLine(doc , '{fieldname}', str).makeError(\" the `{fieldname}` field is not valid because:\\ n\" +str(e)))
280
+ {spc} errors.append(SourceLine(_doc , '{fieldname}', str).makeError(\" the `{fieldname}` field is not valid because:\\ n\" +str(e)))
246
281
""" .
247
282
format (safename = self .safe_name (name ),
248
283
fieldname = shortname (name ),
249
284
fieldtype = fieldtype .name ,
250
285
spc = spc ))
251
286
if optional :
252
287
self .out .write (""" else:
253
- self. {safename} = None
288
+ {safename} = None
254
289
""" .format (safename = self .safe_name (name )))
255
290
256
291
self .out .write ("\n " )
257
292
258
293
if name == self .idfield or not self .idfield :
259
294
baseurl = 'base_url'
260
295
else :
261
- baseurl = " self.%s" % self .safe_name (self .idfield )
296
+ baseurl = ' self.{}' . format ( self .safe_name (self .idfield ) )
262
297
263
298
if fieldtype .is_uri :
264
299
self .serializer .write ("""
0 commit comments