Skip to content

Commit 298be83

Browse files
♻️ Make Schema converter object per schema. Fix generating double schemas.
1 parent 11bac24 commit 298be83

File tree

12 files changed

+288
-254
lines changed

12 files changed

+288
-254
lines changed

src/lapidary/render/model/conv_openapi.py

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import itertools
22
import logging
3+
from collections import defaultdict
34
from collections.abc import Callable, Iterable, Mapping, MutableMapping, Sequence
4-
from typing import Any, cast
5+
from typing import Any
56

67
from mimeparse import parse_media_range
78

89
from .. import json_pointer, names
9-
from . import openapi, python
10+
from . import metamodel, openapi, python
1011
from .conv_schema import OpenApi30SchemaConverter
11-
from .metamodel import resolve_type_name
12+
from .metamodel import MetaModel, resolve_type_name
1213
from .python import type_hint
1314
from .refs import resolve_ref
1415
from .stack import Stack
@@ -22,7 +23,6 @@ def __init__(
2223
root_package: python.ModulePath,
2324
source: openapi.OpenAPI,
2425
origin: str | None,
25-
schema_converter: OpenApi30SchemaConverter | None = None,
2626
path_progress: Callable[[Any], None] | None = None,
2727
):
2828
self.root_package = root_package
@@ -32,8 +32,6 @@ def __init__(
3232
self._origin = origin
3333
self._path_progress = path_progress
3434

35-
self.schema_converter = schema_converter or OpenApi30SchemaConverter(self.root_package, source)
36-
3735
self.target = python.ClientModel(
3836
client=python.ClientModule(
3937
path=python.ModulePath((str(self.root_package), 'client')),
@@ -44,6 +42,12 @@ def __init__(
4442

4543
self._response_cache: MutableMapping[Stack, python.Response] = {}
4644

45+
self._models: MutableMapping[Stack, metamodel.MetaModel] = {}
46+
"""
47+
Store all models directly referred by methods.
48+
Indirectly referred must be accessible via the direct models.
49+
"""
50+
4751
def process(self) -> python.ClientModel:
4852
stack = Stack()
4953

@@ -59,9 +63,37 @@ def process(self) -> python.ClientModel:
5963
},
6064
)
6165

62-
self.target.model_modules.extend(self.schema_converter.schema_modules)
66+
models: MutableMapping[Stack, python.SchemaClass] = {}
67+
for model in self._models.values():
68+
self._collect_schema_models(model, models)
69+
70+
modules: Mapping[python.ModulePath, list[python.SchemaClass]] = defaultdict(list)
71+
for stack, class_ in models.items():
72+
modules[python.ModulePath(resolve_type_name(str(self.root_package), stack).typ.module)].append(class_)
73+
74+
self.target.model_modules.extend(
75+
(
76+
python.SchemaModule(
77+
path=module_path,
78+
body=models,
79+
)
80+
for module_path, models in modules.items()
81+
)
82+
)
83+
6384
return self.target
6485

86+
def _collect_schema_models(
87+
self, model: metamodel.MetaModel, models: MutableMapping[Stack, python.SchemaClass]
88+
) -> None:
89+
try:
90+
if class_ := model.as_type(str(self.root_package)):
91+
models[model.stack] = class_
92+
except Exception:
93+
raise
94+
for submodel in model.dependencies():
95+
self._collect_schema_models(submodel, models)
96+
6597
def process_servers(self, value: list[openapi.Server] | None, stack: Stack) -> None:
6698
logger.debug('Process servers %s', stack)
6799

@@ -108,13 +140,13 @@ def _process_schema_or_content(
108140
if value.param_schema and value.content:
109141
raise ValueError()
110142
if value.param_schema:
111-
model = self.schema_converter.process_type_schema(value.param_schema, stack.push('schema'))
143+
model = self._process_schema(value.param_schema, stack.push('schema'))
112144
assert model
113145
return model.as_annotation(str(self.root_package), value.required), None
114146
elif value.content:
115147
media_type, media_type_obj = next(iter(value.content.items()))
116148
# encoding = media_type_obj.encoding
117-
model = self.schema_converter.process_type_schema(
149+
model = self._process_schema(
118150
media_type_obj.media_type_schema or openapi.Schema(), stack.push('content', media_type)
119151
)
120152
assert model
@@ -236,14 +268,20 @@ def process_content(self, value: Mapping[str, openapi.MediaType], stack: Stack)
236268
mime_parsed = parse_media_range(mime)
237269
if mime_parsed[:2] != ('application', 'json'):
238270
continue
239-
model = self.schema_converter.process_type_schema(
240-
media_type.media_type_schema or openapi.Schema(),
241-
stack.push(mime, 'schema'),
242-
)
271+
model = self._process_schema(media_type.media_type_schema or openapi.Schema(), stack.push(mime, 'schema'))
243272
assert model
244273
types[mime] = model.as_annotation(str(self.root_package))
245274
return types
246275

276+
@resolve_ref
277+
def _process_schema(self, value: openapi.Schema, stack: Stack) -> MetaModel | None:
278+
if not (model := self._models.get(stack)):
279+
converter = OpenApi30SchemaConverter(value, stack, self.root_package, self.source)
280+
if (model := converter.process_schema()) is not None:
281+
self._models[stack] = model
282+
283+
return model
284+
247285
def process_operation(
248286
self,
249287
value: openapi.Operation,
@@ -355,7 +393,9 @@ def process_security_requirement(
355393
schemes_root = Stack(('#', 'components', 'securitySchemes'))
356394
for scheme_name, scopes in value.items():
357395
scheme_stack = schemes_root.push(scheme_name)
358-
self.process_security_scheme(openapi.Reference(ref=str(scheme_stack)), scheme_stack)
396+
self.process_security_scheme(
397+
openapi.Reference[openapi.SecurityRequirement](ref=str(scheme_stack)), scheme_stack
398+
)
359399
return value
360400

361401
# need separate method to resolve references before calling a single-dispatched method
@@ -467,11 +507,6 @@ def process_security_scheme_http(self, value: openapi.SecurityScheme, stack: Sta
467507
except KeyError:
468508
raise NotImplementedError(stack.push('scheme'), value.scheme) from None
469509

470-
def resolve_ref[Target](self, ref: openapi.Reference) -> tuple[Target, Stack]:
471-
"""Resolve reference to OpenAPI object and its direct path."""
472-
value, pointer = self.source.resolve_ref(ref)
473-
return cast(Target, value), Stack.from_str(pointer)
474-
475510

476511
def param_style(
477512
style: str | None,

0 commit comments

Comments
 (0)