Skip to content

Commit 4d78043

Browse files
committed
x-model extension import model class
1 parent 1aece83 commit 4d78043

File tree

5 files changed

+75
-11
lines changed

5 files changed

+75
-11
lines changed

openapi_core/extensions/models/factories.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
"""OpenAPI X-Model extension factories module"""
2+
from pydoc import ErrorDuringImport
3+
from pydoc import locate
24
from typing import Any
35
from typing import Dict
46
from typing import Optional
@@ -22,12 +24,36 @@ def __init__(
2224
self.model_class_factory = model_class_factory or ModelClassFactory()
2325

2426
def create(
25-
self, properties: Optional[Dict[str, Any]], name: Optional[str] = None
27+
self,
28+
name: str,
29+
**properties: Any,
2630
) -> Model:
27-
name = name or "Model"
28-
2931
model_class = self._create_class(name)
30-
return model_class(properties)
32+
return model_class(**properties)
3133

3234
def _create_class(self, name: str) -> Type[Model]:
3335
return self.model_class_factory.create(name)
36+
37+
38+
class ModelImporter(ModelFactory):
39+
def __init__(
40+
self,
41+
model_class_factory: Optional[ModelClassFactory] = None,
42+
models: Optional[Dict[str, Any]] = None,
43+
):
44+
super().__init__(model_class_factory)
45+
self.registry = models or {}
46+
47+
def create(self, name: str, **properties: Any):
48+
model_class = self._get_class(name)
49+
50+
if model_class is None:
51+
model_class = self._create_class(name)
52+
53+
return model_class(**properties)
54+
55+
def _get_class(self, name: str):
56+
try:
57+
return locate(name)
58+
except ErrorDuringImport:
59+
return None

openapi_core/extensions/models/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def __dict__(self) -> Dict[Any, Any]: # type: ignore
1515
class Model(BaseModel):
1616
"""Model class for OpenAPI X-Model."""
1717

18-
def __init__(self, properties: Optional[Dict[str, Any]] = None):
18+
def __init__(self, **properties: Any):
1919
self.__properties = properties or {}
2020

2121
@property

openapi_core/unmarshalling/schemas/unmarshallers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from openapi_schema_validator._format import oas30_format_checker
1717
from openapi_schema_validator._types import is_string
1818

19-
from openapi_core.extensions.models.factories import ModelFactory
19+
from openapi_core.extensions.models.factories import ModelImporter
2020
from openapi_core.schema.schemas import get_all_properties
2121
from openapi_core.schema.schemas import get_all_properties_names
2222
from openapi_core.spec import Spec
@@ -196,8 +196,8 @@ class ObjectUnmarshaller(ComplexUnmarshaller):
196196
}
197197

198198
@property
199-
def model_factory(self) -> ModelFactory:
200-
return ModelFactory()
199+
def model_factory(self) -> ModelImporter:
200+
return ModelImporter()
201201

202202
def unmarshal(self, value: Any) -> Any:
203203
try:
@@ -232,7 +232,7 @@ def _unmarshal_object(self, value: Any) -> Any:
232232

233233
if "x-model" in self.schema:
234234
name = self.schema["x-model"]
235-
return self.model_factory.create(properties, name=name)
235+
return self.model_factory.create(name, **properties)
236236

237237
return properties
238238

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from dataclasses import dataclass
2+
from sys import modules
3+
from types import ModuleType
4+
5+
import pytest
6+
7+
from openapi_core.extensions.models.factories import ModelImporter
8+
from openapi_core.extensions.models.models import Model
9+
10+
11+
class TestImportModelCreate:
12+
@pytest.fixture
13+
def loaded_model_class(self):
14+
@dataclass
15+
class BarModel:
16+
a: str
17+
b: int
18+
19+
foo_module = ModuleType("foo")
20+
foo_module.BarModel = BarModel
21+
modules["foo"] = foo_module
22+
yield BarModel
23+
del modules["foo"]
24+
25+
def test_dynamic_model(self):
26+
factory = ModelImporter()
27+
28+
test_model = factory.create("TestModel")
29+
30+
assert test_model.__class__.__name__ == "TestModel"
31+
assert test_model.__class__.__bases__ == (Model,)
32+
33+
def test_imported_model(self, loaded_model_class):
34+
factory = ModelImporter()
35+
36+
test_model = factory.create("foo.BarModel", a="test", b=11)
37+
38+
assert test_model.__class__ == loaded_model_class

tests/unit/extensions/test_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_dict(self):
2525
"prop1": "value1",
2626
"prop2": "value2",
2727
}
28-
model = Model(properties)
28+
model = Model(**properties)
2929

3030
result = model.__dict__
3131

@@ -36,7 +36,7 @@ def test_attribute(self):
3636
properties = {
3737
"prop1": prop_value,
3838
}
39-
model = Model(properties)
39+
model = Model(**properties)
4040

4141
result = model.prop1
4242

0 commit comments

Comments
 (0)