diff --git a/openapi_core/schema/schemas.py b/openapi_core/schema/schemas.py index 9cdc2e92..977e426b 100644 --- a/openapi_core/schema/schemas.py +++ b/openapi_core/schema/schemas.py @@ -4,15 +4,7 @@ from openapi_core.spec import Spec -def get_all_properties(schema: Spec) -> Dict[str, Any]: +def get_properties(schema: Spec) -> Dict[str, Any]: properties = schema.get("properties", {}) properties_dict = dict(list(properties.items())) - - if "allOf" not in schema: - return properties_dict - - for subschema in schema / "allOf": - subschema_props = get_all_properties(subschema) - properties_dict.update(subschema_props) - return properties_dict diff --git a/openapi_core/unmarshalling/schemas/factories.py b/openapi_core/unmarshalling/schemas/factories.py index 41e3e3aa..9440cae9 100644 --- a/openapi_core/unmarshalling/schemas/factories.py +++ b/openapi_core/unmarshalling/schemas/factories.py @@ -6,6 +6,7 @@ from typing import Type from typing import Union +from backports.cached_property import cached_property from jsonschema.protocols import Validator from openapi_schema_validator import OAS30Validator @@ -41,6 +42,42 @@ from openapi_core.unmarshalling.schemas.util import build_format_checker +class SchemaValidatorsFactory: + + CONTEXTS = { + UnmarshalContext.REQUEST: "write", + UnmarshalContext.RESPONSE: "read", + } + + def __init__( + self, + schema_validator_class: Type[Validator], + custom_formatters: Optional[CustomFormattersDict] = None, + context: Optional[UnmarshalContext] = None, + ): + self.schema_validator_class = schema_validator_class + if custom_formatters is None: + custom_formatters = {} + self.custom_formatters = custom_formatters + self.context = context + + def create(self, schema: Spec) -> Validator: + resolver = schema.accessor.resolver # type: ignore + custom_format_checks = { + name: formatter.validate + for name, formatter in self.custom_formatters.items() + } + format_checker = build_format_checker(**custom_format_checks) + kwargs = { + "resolver": resolver, + "format_checker": format_checker, + } + if self.context is not None: + kwargs[self.CONTEXTS[self.context]] = True + with schema.open() as schema_dict: + return self.schema_validator_class(schema_dict, **kwargs) + + class SchemaUnmarshallersFactory: UNMARSHALLERS: Dict[str, Type[BaseSchemaUnmarshaller]] = { @@ -60,11 +97,6 @@ class SchemaUnmarshallersFactory: "any": AnyUnmarshaller, } - CONTEXT_VALIDATION = { - UnmarshalContext.REQUEST: "write", - UnmarshalContext.RESPONSE: "read", - } - def __init__( self, schema_validator_class: Type[Validator], @@ -77,6 +109,14 @@ def __init__( self.custom_formatters = custom_formatters self.context = context + @cached_property + def validators_factory(self) -> SchemaValidatorsFactory: + return SchemaValidatorsFactory( + self.schema_validator_class, + self.custom_formatters, + self.context, + ) + def create( self, schema: Spec, type_override: Optional[str] = None ) -> BaseSchemaUnmarshaller: @@ -87,7 +127,7 @@ def create( if schema.getkey("deprecated", False): warnings.warn("The schema is deprecated", DeprecationWarning) - validator = self.get_validator(schema) + validator = self.validators_factory.create(schema) schema_format = schema.getkey("format") formatter = self.custom_formatters.get(schema_format) @@ -97,29 +137,29 @@ def create( schema_type, str ): return MultiTypeUnmarshaller( - schema, validator, formatter, self, context=self.context + schema, + validator, + formatter, + self.validators_factory, + self, + context=self.context, ) if schema_type in self.COMPLEX_UNMARSHALLERS: complex_klass = self.COMPLEX_UNMARSHALLERS[schema_type] return complex_klass( - schema, validator, formatter, self, context=self.context + schema, + validator, + formatter, + self.validators_factory, + self, + context=self.context, ) klass = self.UNMARSHALLERS[schema_type] - return klass(schema, validator, formatter) - - def get_validator(self, schema: Spec) -> Validator: - resolver = schema.accessor.resolver # type: ignore - custom_format_checks = { - name: formatter.validate - for name, formatter in self.custom_formatters.items() - } - format_checker = build_format_checker(**custom_format_checks) - kwargs = { - "resolver": resolver, - "format_checker": format_checker, - } - if self.context is not None: - kwargs[self.CONTEXT_VALIDATION[self.context]] = True - with schema.open() as schema_dict: - return self.schema_validator_class(schema_dict, **kwargs) + return klass( + schema, + validator, + formatter, + self.validators_factory, + self, + ) diff --git a/openapi_core/unmarshalling/schemas/unmarshallers.py b/openapi_core/unmarshalling/schemas/unmarshallers.py index 9329dc78..94baea8c 100644 --- a/openapi_core/unmarshalling/schemas/unmarshallers.py +++ b/openapi_core/unmarshalling/schemas/unmarshallers.py @@ -15,12 +15,13 @@ from jsonschema._types import is_null from jsonschema._types import is_number from jsonschema._types import is_object +from jsonschema.exceptions import ValidationError from jsonschema.protocols import Validator from openapi_schema_validator._format import oas30_format_checker from openapi_schema_validator._types import is_string from openapi_core.extensions.models.factories import ModelPathFactory -from openapi_core.schema.schemas import get_all_properties +from openapi_core.schema.schemas import get_properties from openapi_core.spec import Spec from openapi_core.unmarshalling.schemas.datatypes import FormattersDict from openapi_core.unmarshalling.schemas.enums import UnmarshalContext @@ -45,6 +46,9 @@ from openapi_core.unmarshalling.schemas.factories import ( SchemaUnmarshallersFactory, ) + from openapi_core.unmarshalling.schemas.factories import ( + SchemaValidatorsFactory, + ) log = logging.getLogger(__name__) @@ -60,6 +64,8 @@ def __init__( schema: Spec, validator: Validator, formatter: Optional[Formatter], + validators_factory: "SchemaValidatorsFactory", + unmarshallers_factory: "SchemaUnmarshallersFactory", ): self.schema = schema self.validator = validator @@ -72,6 +78,9 @@ def __init__( else: self.formatter = formatter + self.validators_factory = validators_factory + self.unmarshallers_factory = unmarshallers_factory + def __call__(self, value: Any) -> Any: self.validate(value) @@ -100,8 +109,92 @@ def format(self, value: Any) -> Any: except (ValueError, TypeError) as exc: raise InvalidSchemaFormatValue(value, self.schema_format, exc) + def _get_best_unmarshaller(self, value: Any) -> "BaseSchemaUnmarshaller": + if "format" not in self.schema: + one_of_schema = self._get_one_of_schema(value) + if one_of_schema is not None and "format" in one_of_schema: + one_of_unmarshaller = self.unmarshallers_factory.create( + one_of_schema + ) + return one_of_unmarshaller + + any_of_schemas = self._iter_any_of_schemas(value) + for any_of_schema in any_of_schemas: + if "format" in any_of_schema: + any_of_unmarshaller = self.unmarshallers_factory.create( + any_of_schema + ) + return any_of_unmarshaller + + all_of_schemas = self._iter_all_of_schemas(value) + for all_of_schema in all_of_schemas: + if "format" in all_of_schema: + all_of_unmarshaller = self.unmarshallers_factory.create( + all_of_schema + ) + return all_of_unmarshaller + + return self + def unmarshal(self, value: Any) -> Any: - return self.format(value) + unmarshaller = self._get_best_unmarshaller(value) + return unmarshaller.format(value) + + def _get_one_of_schema( + self, + value: Any, + ) -> Optional[Spec]: + if "oneOf" not in self.schema: + return None + + one_of_schemas = self.schema / "oneOf" + for subschema in one_of_schemas: + validator = self.validators_factory.create(subschema) + try: + validator.validate(value) + except ValidationError: + continue + else: + return subschema + + log.warning("valid oneOf schema not found") + return None + + def _iter_any_of_schemas( + self, + value: Any, + ) -> Iterator[Spec]: + if "anyOf" not in self.schema: + return + + any_of_schemas = self.schema / "anyOf" + for subschema in any_of_schemas: + validator = self.validators_factory.create(subschema) + try: + validator.validate(value) + except ValidationError: + continue + else: + yield subschema + + def _iter_all_of_schemas( + self, + value: Any, + ) -> Iterator[Spec]: + if "allOf" not in self.schema: + return + + all_of_schemas = self.schema / "allOf" + for subschema in all_of_schemas: + if "type" not in subschema: + continue + validator = self.validators_factory.create(subschema) + try: + validator.validate(value) + except ValidationError: + log.warning("invalid allOf schema found") + else: + yield subschema class StringUnmarshaller(BaseSchemaUnmarshaller): @@ -178,11 +271,17 @@ def __init__( schema: Spec, validator: Validator, formatter: Optional[Formatter], + validators_factory: "SchemaValidatorsFactory", unmarshallers_factory: "SchemaUnmarshallersFactory", context: Optional[UnmarshalContext] = None, ): - super().__init__(schema, validator, formatter) - self.unmarshallers_factory = unmarshallers_factory + super().__init__( + schema, + validator, + formatter, + validators_factory, + unmarshallers_factory, + ) self.context = context @@ -221,70 +320,62 @@ def unmarshal(self, value: Any) -> Any: return object_class(**properties) - def format(self, value: Any) -> Any: + def format(self, value: Any, schema_only: bool = False) -> Any: formatted = super().format(value) - return self._unmarshal_properties(formatted) + return self._unmarshal_properties(formatted, schema_only=schema_only) def _clone(self, schema: Spec) -> "ObjectUnmarshaller": return cast( "ObjectUnmarshaller", - self.unmarshallers_factory.create(schema, "object"), + self.unmarshallers_factory.create(schema, type_override="object"), ) - def _unmarshal_properties(self, value: Any) -> Any: + def _unmarshal_properties( + self, value: Any, schema_only: bool = False + ) -> Any: properties = {} - if "oneOf" in self.schema: - one_of_properties = None - for one_of_schema in self.schema / "oneOf": - try: - unmarshalled = self._clone(one_of_schema).format(value) - except (UnmarshalError, ValueError): - pass - else: - if one_of_properties is not None: - log.warning("multiple valid oneOf schemas found") - continue - one_of_properties = unmarshalled - - if one_of_properties is None: - log.warning("valid oneOf schema not found") - else: - properties.update(one_of_properties) - - elif "anyOf" in self.schema: - any_of_properties = None - for any_of_schema in self.schema / "anyOf": - try: - unmarshalled = self._clone(any_of_schema).format(value) - except (UnmarshalError, ValueError): - pass - else: - any_of_properties = unmarshalled - break - - if any_of_properties is None: - log.warning("valid anyOf schema not found") - else: - properties.update(any_of_properties) + one_of_schema = self._get_one_of_schema(value) + if one_of_schema is not None: + one_of_properties = self._clone(one_of_schema).format( + value, schema_only=True + ) + properties.update(one_of_properties) + + any_of_schemas = self._iter_any_of_schemas(value) + for any_of_schema in any_of_schemas: + any_of_properties = self._clone(any_of_schema).format( + value, schema_only=True + ) + properties.update(any_of_properties) - for prop_name, prop in get_all_properties(self.schema).items(): - read_only = prop.getkey("readOnly", False) + all_of_schemas = self._iter_all_of_schemas(value) + for all_of_schema in all_of_schemas: + all_of_properties = self._clone(all_of_schema).format( + value, schema_only=True + ) + properties.update(all_of_properties) + + for prop_name, prop_schema in get_properties(self.schema).items(): + read_only = prop_schema.getkey("readOnly", False) if self.context == UnmarshalContext.REQUEST and read_only: continue - write_only = prop.getkey("writeOnly", False) + write_only = prop_schema.getkey("writeOnly", False) if self.context == UnmarshalContext.RESPONSE and write_only: continue try: prop_value = value[prop_name] except KeyError: - if "default" not in prop: + if "default" not in prop_schema: continue - prop_value = prop["default"] + prop_value = prop_schema["default"] - properties[prop_name] = self.unmarshallers_factory.create(prop)( - prop_value - ) + properties[prop_name] = self.unmarshallers_factory.create( + prop_schema + )(prop_value) + + if schema_only: + return properties additional_properties = self.schema.getkey( "additionalProperties", True @@ -340,7 +431,7 @@ def _get_best_unmarshaller(self, value: Any) -> "BaseSchemaUnmarshaller": def unmarshal(self, value: Any) -> Any: unmarshaller = self._get_best_unmarshaller(value) - return unmarshaller(value) + return unmarshaller.unmarshal(value) class AnyUnmarshaller(MultiTypeUnmarshaller): @@ -357,65 +448,3 @@ class AnyUnmarshaller(MultiTypeUnmarshaller): @property def type(self) -> List[str]: return self.SCHEMA_TYPES_ORDER - - def unmarshal(self, value: Any) -> Any: - one_of_schema = self._get_one_of_schema(value) - if one_of_schema: - return self.unmarshallers_factory.create(one_of_schema)(value) - - any_of_schema = self._get_any_of_schema(value) - if any_of_schema: - return self.unmarshallers_factory.create(any_of_schema)(value) - - all_of_schema = self._get_all_of_schema(value) - if all_of_schema: - return self.unmarshallers_factory.create(all_of_schema)(value) - - return super().unmarshal(value) - - def _get_one_of_schema(self, value: Any) -> Optional[Spec]: - if "oneOf" not in self.schema: - return None - - one_of_schemas = self.schema / "oneOf" - for subschema in one_of_schemas: - unmarshaller = self.unmarshallers_factory.create(subschema) - try: - unmarshaller.validate(value) - except ValidateError: - continue - else: - return subschema - return None - - def _get_any_of_schema(self, value: Any) -> Optional[Spec]: - if "anyOf" not in self.schema: - return None - - any_of_schemas = self.schema / "anyOf" - for subschema in any_of_schemas: - unmarshaller = self.unmarshallers_factory.create(subschema) - try: - unmarshaller.validate(value) - except ValidateError: - continue - else: - return subschema - return None - - def _get_all_of_schema(self, value: Any) -> Optional[Spec]: - if "allOf" not in self.schema: - return None - - all_of_schemas = self.schema / "allOf" - for subschema in all_of_schemas: - if "type" not in subschema: - continue - unmarshaller = self.unmarshallers_factory.create(subschema) - try: - unmarshaller.validate(value) - except ValidateError: - continue - else: - return subschema - return None diff --git a/poetry.lock b/poetry.lock index c5f2d2d5..4b9cbe99 100644 --- a/poetry.lock +++ b/poetry.lock @@ -71,6 +71,14 @@ python-versions = ">=3.6" [package.dependencies] pytz = ">=2015.7" +[[package]] +name = "backports-cached-property" +version = "1.0.2" +description = "cached_property() - computed once per instance, cached as attribute" +category = "main" +optional = false +python-versions = ">=3.6.0" + [[package]] name = "black" version = "22.10.0" @@ -1117,7 +1125,7 @@ starlette = [] [metadata] lock-version = "1.1" python-versions = "^3.7.0" -content-hash = "25d23ad11b888728528627234a4d5f017d744c9a96e2a1a953a6129595464e9e" +content-hash = "49bdb4e150245eb8dec5b3c7a4de8473e9beb8f8790d7d8af454d526ffae538d" [metadata.files] alabaster = [ @@ -1143,6 +1151,10 @@ babel = [ {file = "Babel-2.11.0-py3-none-any.whl", hash = "sha256:1ad3eca1c885218f6dce2ab67291178944f810a10a9b5f3cb8382a5a232b64fe"}, {file = "Babel-2.11.0.tar.gz", hash = "sha256:5ef4b3226b0180dedded4229651c8b0e1a3a6a2837d45a073272f313e4cf97f6"}, ] +backports-cached-property = [ + {file = "backports.cached-property-1.0.2.tar.gz", hash = "sha256:9306f9eed6ec55fd156ace6bc1094e2c86fae5fb2bf07b6a9c00745c656e75dd"}, + {file = "backports.cached_property-1.0.2-py3-none-any.whl", hash = "sha256:baeb28e1cd619a3c9ab8941431fe34e8490861fb998c6c4590693d50171db0cc"}, +] black = [ {file = "black-22.10.0-1fixedarch-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:5cc42ca67989e9c3cf859e84c2bf014f6633db63d1cbdf8fdb666dcd9e77e3fa"}, {file = "black-22.10.0-1fixedarch-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:5d8f74030e67087b219b032aa33a919fae8806d49c867846bfacde57f43972ef"}, diff --git a/pyproject.toml b/pyproject.toml index 2014fb4b..09a24ac3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ requests = {version = "*", optional = true} werkzeug = "*" typing-extensions = "^4.3.0" jsonschema-spec = "^0.1.1" +backports-cached-property = "^1.0.2" [tool.poetry.extras] django = ["django"] diff --git a/tests/integration/contrib/django/data/v3.0/djangoproject/pets/views.py b/tests/integration/contrib/django/data/v3.0/djangoproject/pets/views.py index f46cb2eb..8e4b38fd 100644 --- a/tests/integration/contrib/django/data/v3.0/djangoproject/pets/views.py +++ b/tests/integration/contrib/django/data/v3.0/djangoproject/pets/views.py @@ -39,9 +39,13 @@ def post(self, request): "api-key": "12345", } assert request.openapi.body.__class__.__name__ == "PetCreate" - assert request.openapi.body.name == "Cat" - assert request.openapi.body.ears.__class__.__name__ == "Ears" - assert request.openapi.body.ears.healthy is True + assert request.openapi.body.name in ["Cat", "Bird"] + if request.openapi.body.name == "Cat": + assert request.openapi.body.ears.__class__.__name__ == "Ears" + assert request.openapi.body.ears.healthy is True + if request.openapi.body.name == "Bird": + assert request.openapi.body.wings.__class__.__name__ == "Wings" + assert request.openapi.body.wings.healthy is True django_response = HttpResponse(status=201) django_response["X-Rate-Limit"] = "12" diff --git a/tests/integration/contrib/django/test_django_project.py b/tests/integration/contrib/django/test_django_project.py index faf64387..1f394fe4 100644 --- a/tests/integration/contrib/django/test_django_project.py +++ b/tests/integration/contrib/django/test_django_project.py @@ -225,16 +225,28 @@ def test_post_required_cookie_param_missing(self, client): assert response.status_code == 400 assert response.json() == expected_data - def test_post_valid(self, client): + @pytest.mark.parametrize( + "data_json", + [ + { + "id": 12, + "name": "Cat", + "ears": { + "healthy": True, + }, + }, + { + "id": 12, + "name": "Bird", + "wings": { + "healthy": True, + }, + }, + ], + ) + def test_post_valid(self, client, data_json): client.cookies.load({"user": 1}) content_type = "application/json" - data_json = { - "id": 12, - "name": "Cat", - "ears": { - "healthy": True, - }, - } headers = { "HTTP_AUTHORIZATION": "Basic testuser", "HTTP_HOST": "staging.gigantic-server.com", diff --git a/tests/integration/contrib/falcon/data/v3.0/falconproject/pets/resources.py b/tests/integration/contrib/falcon/data/v3.0/falconproject/pets/resources.py index 154d50ff..ff22b599 100644 --- a/tests/integration/contrib/falcon/data/v3.0/falconproject/pets/resources.py +++ b/tests/integration/contrib/falcon/data/v3.0/falconproject/pets/resources.py @@ -38,9 +38,18 @@ def on_post(self, request, response): "api-key": "12345", } assert request.context.openapi.body.__class__.__name__ == "PetCreate" - assert request.context.openapi.body.name == "Cat" - assert request.context.openapi.body.ears.__class__.__name__ == "Ears" - assert request.context.openapi.body.ears.healthy is True + assert request.context.openapi.body.name in ["Cat", "Bird"] + if request.context.openapi.body.name == "Cat": + assert ( + request.context.openapi.body.ears.__class__.__name__ == "Ears" + ) + assert request.context.openapi.body.ears.healthy is True + if request.context.openapi.body.name == "Bird": + assert ( + request.context.openapi.body.wings.__class__.__name__ + == "Wings" + ) + assert request.context.openapi.body.wings.healthy is True response.status = HTTP_201 response.set_header("X-Rate-Limit", "12") diff --git a/tests/integration/contrib/falcon/test_falcon_project.py b/tests/integration/contrib/falcon/test_falcon_project.py index 547fda0f..866a3e99 100644 --- a/tests/integration/contrib/falcon/test_falcon_project.py +++ b/tests/integration/contrib/falcon/test_falcon_project.py @@ -210,16 +210,28 @@ def test_post_required_cookie_param_missing(self, client): assert response.status_code == 400 assert response.json == expected_data - def test_post_valid(self, client): + @pytest.mark.parametrize( + "data_json", + [ + { + "id": 12, + "name": "Cat", + "ears": { + "healthy": True, + }, + }, + { + "id": 12, + "name": "Bird", + "wings": { + "healthy": True, + }, + }, + ], + ) + def test_post_valid(self, client, data_json): cookies = {"user": 1} content_type = "application/json" - data_json = { - "id": 12, - "name": "Cat", - "ears": { - "healthy": True, - }, - } headers = { "Authorization": "Basic testuser", "Api-Key": self.api_key_encoded, diff --git a/tests/integration/data/v3.0/petstore.yaml b/tests/integration/data/v3.0/petstore.yaml index b1647556..5e2be72a 100644 --- a/tests/integration/data/v3.0/petstore.yaml +++ b/tests/integration/data/v3.0/petstore.yaml @@ -461,6 +461,7 @@ components: content: application/json: schema: + x-model: Error oneOf: - $ref: "#/components/schemas/StandardErrors" - $ref: "#/components/schemas/ExtendedError" diff --git a/tests/unit/unmarshalling/test_unmarshal.py b/tests/unit/unmarshalling/test_unmarshal.py index e0ef91c3..d0e2018b 100644 --- a/tests/unit/unmarshalling/test_unmarshal.py +++ b/tests/unit/unmarshalling/test_unmarshal.py @@ -630,6 +630,100 @@ def test_schema_object_any_of_invalid(self, unmarshaller_factory): with pytest.raises(UnmarshalError): unmarshaller_factory(spec)({"someint": "1"}) + def test_schema_object_one_of_default(self, unmarshaller_factory): + schema = { + "type": "object", + "oneOf": [ + { + "type": "object", + "properties": { + "somestr": { + "type": "string", + "default": "defaultstring", + }, + }, + }, + { + "type": "object", + "required": ["otherstr"], + "properties": { + "otherstr": { + "type": "string", + }, + }, + }, + ], + "properties": { + "someint": { + "type": "integer", + }, + }, + } + spec = Spec.from_dict(schema) + assert unmarshaller_factory(spec)({"someint": 1}) == { + "someint": 1, + "somestr": "defaultstring", + } + + def test_schema_object_any_of_default(self, unmarshaller_factory): + schema = { + "type": "object", + "anyOf": [ + { + "type": "object", + "properties": { + "someint": { + "type": "integer", + }, + }, + }, + { + "type": "object", + "properties": { + "somestr": { + "type": "string", + "default": "defaultstring", + }, + }, + }, + ], + } + spec = Spec.from_dict(schema) + assert unmarshaller_factory(spec)({"someint": "1"}) == { + "someint": "1", + "somestr": "defaultstring", + } + + def test_schema_object_all_of_default(self, unmarshaller_factory): + schema = { + "type": "object", + "allOf": [ + { + "type": "object", + "properties": { + "somestr": { + "type": "string", + "default": "defaultstring", + }, + }, + }, + { + "type": "object", + "properties": { + "someint": { + "type": "integer", + "default": 1, + }, + }, + }, + ], + } + spec = Spec.from_dict(schema) + assert unmarshaller_factory(spec)({}) == { + "someint": 1, + "somestr": "defaultstring", + } + def test_schema_any_all_of(self, unmarshaller_factory): schema = { "allOf": [ @@ -697,6 +791,23 @@ def test_schema_any_all_of_invalid_properties( with pytest.raises(InvalidSchemaValue): unmarshaller_factory(spec)(value) + def test_schema_any_any_of_any(self, unmarshaller_factory): + schema = { + "anyOf": [ + {}, + { + "type": "string", + "format": "date", + }, + ], + } + spec = Spec.from_dict(schema) + value = "2018-01-02" + + result = unmarshaller_factory(spec)(value) + + assert result == datetime.date(2018, 1, 2) + def test_schema_any_all_of_any(self, unmarshaller_factory): schema = { "allOf": [