diff --git a/src/graphql/pyutils/undefined.py b/src/graphql/pyutils/undefined.py index 8a078eba..a5ab96ec 100644 --- a/src/graphql/pyutils/undefined.py +++ b/src/graphql/pyutils/undefined.py @@ -1,4 +1,5 @@ -from typing import Any +import warnings +from typing import Any, Optional __all__ = ["Undefined", "UndefinedType"] @@ -7,6 +8,18 @@ class UndefinedType(ValueError): """Auxiliary class for creating the Undefined singleton.""" + _instance: Optional["UndefinedType"] = None + + def __new__(cls) -> "UndefinedType": + if cls._instance is None: + cls._instance = super().__new__(cls) + else: + warnings.warn("Redefinition of 'Undefined'", RuntimeWarning, stacklevel=2) + return cls._instance + + def __reduce__(self) -> str: + return "Undefined" + def __repr__(self) -> str: return "Undefined" diff --git a/src/graphql/type/definition.py b/src/graphql/type/definition.py index 0f5895ac..d9fe289d 100644 --- a/src/graphql/type/definition.py +++ b/src/graphql/type/definition.py @@ -231,6 +231,23 @@ class GraphQLNamedType(GraphQLType): ast_node: Optional[TypeDefinitionNode] extension_ast_nodes: Tuple[TypeExtensionNode, ...] + reserved_types: Dict[str, "GraphQLNamedType"] = {} + + def __new__(cls, name: str, *_args: Any, **_kwargs: Any) -> "GraphQLNamedType": + if name in cls.reserved_types: + raise TypeError(f"Redefinition of reserved type {name!r}") + return super().__new__(cls) + + def __reduce__(self) -> Tuple[Callable, Tuple]: + return self._get_instance, (self.name, tuple(self.to_kwargs().items())) + + @classmethod + def _get_instance(cls, name: str, args: Tuple) -> "GraphQLNamedType": + try: + return cls.reserved_types[name] + except KeyError: + return cls(**dict(args)) + def __init__( self, name: str, diff --git a/src/graphql/type/introspection.py b/src/graphql/type/introspection.py index 160c582c..17922d21 100644 --- a/src/graphql/type/introspection.py +++ b/src/graphql/type/introspection.py @@ -8,6 +8,7 @@ GraphQLEnumType, GraphQLEnumValue, GraphQLField, + GraphQLFieldMap, GraphQLList, GraphQLNamedType, GraphQLNonNull, @@ -35,88 +36,105 @@ ] -__Schema: GraphQLObjectType = GraphQLObjectType( +class SchemaFields(GraphQLFieldMap): + def __new__(cls): + return { + "description": GraphQLField(GraphQLString, resolve=cls.description), + "types": GraphQLField( + GraphQLNonNull(GraphQLList(GraphQLNonNull(_Type))), + resolve=cls.types, + description="A list of all types supported by this server.", + ), + "queryType": GraphQLField( + GraphQLNonNull(_Type), + resolve=cls.query_type, + description="The type that query operations will be rooted at.", + ), + "mutationType": GraphQLField( + _Type, + resolve=cls.mutation_type, + description="If this server supports mutation, the type that" + " mutation operations will be rooted at.", + ), + "subscriptionType": GraphQLField( + _Type, + resolve=cls.subscription_type, + description="If this server supports subscription, the type that" + " subscription operations will be rooted at.", + ), + "directives": GraphQLField( + GraphQLNonNull(GraphQLList(GraphQLNonNull(_Directive))), + resolve=cls.directives, + description="A list of all directives supported by this server.", + ), + } + + @staticmethod + def description(schema, _info): + return schema.description + + @staticmethod + def types(schema, _info): + return schema.type_map.values() + + @staticmethod + def query_type(schema, _info): + return schema.query_type + + @staticmethod + def mutation_type(schema, _info): + return schema.mutation_type + + @staticmethod + def subscription_type(schema, _info): + return 
schema.subscription_type + + @staticmethod + def directives(schema, _info): + return schema.directives + + +_Schema: GraphQLObjectType = GraphQLObjectType( name="__Schema", description="A GraphQL Schema defines the capabilities of a GraphQL" " server. It exposes all available types and directives" " on the server, as well as the entry points for query," " mutation, and subscription operations.", - fields=lambda: { - "description": GraphQLField( - GraphQLString, resolve=lambda schema, _info: schema.description - ), - "types": GraphQLField( - GraphQLNonNull(GraphQLList(GraphQLNonNull(__Type))), - resolve=lambda schema, _info: schema.type_map.values(), - description="A list of all types supported by this server.", - ), - "queryType": GraphQLField( - GraphQLNonNull(__Type), - resolve=lambda schema, _info: schema.query_type, - description="The type that query operations will be rooted at.", - ), - "mutationType": GraphQLField( - __Type, - resolve=lambda schema, _info: schema.mutation_type, - description="If this server supports mutation, the type that" - " mutation operations will be rooted at.", - ), - "subscriptionType": GraphQLField( - __Type, - resolve=lambda schema, _info: schema.subscription_type, - description="If this server support subscription, the type that" - " subscription operations will be rooted at.", - ), - "directives": GraphQLField( - GraphQLNonNull(GraphQLList(GraphQLNonNull(__Directive))), - resolve=lambda schema, _info: schema.directives, - description="A list of all directives supported by this server.", - ), - }, + fields=SchemaFields, ) -__Directive: GraphQLObjectType = GraphQLObjectType( - name="__Directive", - description="A Directive provides a way to describe alternate runtime" - " execution and type validation behavior in a GraphQL" - " document.\n\nIn some cases, you need to provide options" - " to alter GraphQL's execution behavior in ways field" - " arguments will not suffice, such as conditionally including" - " or skipping a field. 
Directives provide this by describing" - " additional information to the executor.", - fields=lambda: { - # Note: The fields onOperation, onFragment and onField are deprecated - "name": GraphQLField( - GraphQLNonNull(GraphQLString), - resolve=DirectiveResolvers.name, - ), - "description": GraphQLField( - GraphQLString, - resolve=DirectiveResolvers.description, - ), - "isRepeatable": GraphQLField( - GraphQLNonNull(GraphQLBoolean), - resolve=DirectiveResolvers.is_repeatable, - ), - "locations": GraphQLField( - GraphQLNonNull(GraphQLList(GraphQLNonNull(__DirectiveLocation))), - resolve=DirectiveResolvers.locations, - ), - "args": GraphQLField( - GraphQLNonNull(GraphQLList(GraphQLNonNull(__InputValue))), - args={ - "includeDeprecated": GraphQLArgument( - GraphQLBoolean, default_value=False - ) - }, - resolve=DirectiveResolvers.args, - ), - }, -) - +class DirectiveFields(GraphQLFieldMap): + def __new__(cls): + return { + # Note: The fields onOperation, onFragment and onField are deprecated + "name": GraphQLField( + GraphQLNonNull(GraphQLString), + resolve=cls.name, + ), + "description": GraphQLField( + GraphQLString, + resolve=cls.description, + ), + "isRepeatable": GraphQLField( + GraphQLNonNull(GraphQLBoolean), + resolve=cls.is_repeatable, + ), + "locations": GraphQLField( + GraphQLNonNull(GraphQLList(GraphQLNonNull(_DirectiveLocation))), + resolve=cls.locations, + ), + "args": GraphQLField( + GraphQLNonNull(GraphQLList(GraphQLNonNull(_InputValue))), + args={ + "includeDeprecated": GraphQLArgument( + GraphQLBoolean, default_value=False + ) + }, + resolve=cls.args, + ), + } -class DirectiveResolvers: @staticmethod def name(directive, _info): return directive.name @@ -144,7 +162,20 @@ def args(directive, _info, includeDeprecated=False): ) -__DirectiveLocation: GraphQLEnumType = GraphQLEnumType( +_Directive: GraphQLObjectType = GraphQLObjectType( + name="__Directive", + description="A Directive provides a way to describe alternate runtime" + " execution and type validation behavior in a GraphQL" + " document.\n\nIn some cases, you need to provide options" + " to alter GraphQL's execution behavior in ways field" + " arguments will not suffice, such as conditionally including" + " or skipping a field. Directives provide this by describing" + " additional information to the executor.", + fields=DirectiveFields, +) + + +_DirectiveLocation: GraphQLEnumType = GraphQLEnumType( name="__DirectiveLocation", description="A Directive can be adjacent to many parts of the GraphQL" " language, a __DirectiveLocation describes one such possible" @@ -229,65 +260,50 @@ def args(directive, _info, includeDeprecated=False): ) -__Type: GraphQLObjectType = GraphQLObjectType( - name="__Type", - description="The fundamental unit of any GraphQL Schema is the type." - " There are many kinds of types in GraphQL as represented" - " by the `__TypeKind` enum.\n\nDepending on the kind of a" - " type, certain fields describe information about that type." - " Scalar types provide no information beyond a name, description" - " and optional `specifiedByURL`, while Enum types provide their values." - " Object and Interface types provide the fields they describe." - " Abstract types, Union and Interface, provide the Object" - " types possible at runtime. 
List and NonNull types compose" - " other types.", - fields=lambda: { - "kind": GraphQLField(GraphQLNonNull(__TypeKind), resolve=TypeResolvers.kind), - "name": GraphQLField(GraphQLString, resolve=TypeResolvers.name), - "description": GraphQLField(GraphQLString, resolve=TypeResolvers.description), - "specifiedByURL": GraphQLField( - GraphQLString, resolve=TypeResolvers.specified_by_url - ), - "fields": GraphQLField( - GraphQLList(GraphQLNonNull(__Field)), - args={ - "includeDeprecated": GraphQLArgument( - GraphQLBoolean, default_value=False - ) - }, - resolve=TypeResolvers.fields, - ), - "interfaces": GraphQLField( - GraphQLList(GraphQLNonNull(__Type)), resolve=TypeResolvers.interfaces - ), - "possibleTypes": GraphQLField( - GraphQLList(GraphQLNonNull(__Type)), - resolve=TypeResolvers.possible_types, - ), - "enumValues": GraphQLField( - GraphQLList(GraphQLNonNull(__EnumValue)), - args={ - "includeDeprecated": GraphQLArgument( - GraphQLBoolean, default_value=False - ) - }, - resolve=TypeResolvers.enum_values, - ), - "inputFields": GraphQLField( - GraphQLList(GraphQLNonNull(__InputValue)), - args={ - "includeDeprecated": GraphQLArgument( - GraphQLBoolean, default_value=False - ) - }, - resolve=TypeResolvers.input_fields, - ), - "ofType": GraphQLField(__Type, resolve=TypeResolvers.of_type), - }, -) - +class TypeFields(GraphQLFieldMap): + def __new__(cls): + return { + "kind": GraphQLField(GraphQLNonNull(_TypeKind), resolve=cls.kind), + "name": GraphQLField(GraphQLString, resolve=cls.name), + "description": GraphQLField(GraphQLString, resolve=cls.description), + "specifiedByURL": GraphQLField(GraphQLString, resolve=cls.specified_by_url), + "fields": GraphQLField( + GraphQLList(GraphQLNonNull(_Field)), + args={ + "includeDeprecated": GraphQLArgument( + GraphQLBoolean, default_value=False + ) + }, + resolve=cls.fields, + ), + "interfaces": GraphQLField( + GraphQLList(GraphQLNonNull(_Type)), resolve=cls.interfaces + ), + "possibleTypes": GraphQLField( + GraphQLList(GraphQLNonNull(_Type)), + resolve=cls.possible_types, + ), + "enumValues": GraphQLField( + GraphQLList(GraphQLNonNull(_EnumValue)), + args={ + "includeDeprecated": GraphQLArgument( + GraphQLBoolean, default_value=False + ) + }, + resolve=cls.enum_values, + ), + "inputFields": GraphQLField( + GraphQLList(GraphQLNonNull(_InputValue)), + args={ + "includeDeprecated": GraphQLArgument( + GraphQLBoolean, default_value=False + ) + }, + resolve=cls.input_fields, + ), + "ofType": GraphQLField(_Type, resolve=cls.of_type), + } -class TypeResolvers: @staticmethod def kind(type_, _info): if is_scalar_type(type_): @@ -370,38 +386,46 @@ def of_type(type_, _info): return getattr(type_, "of_type", None) -__Field: GraphQLObjectType = GraphQLObjectType( - name="__Field", - description="Object and Interface types are described by a list of Fields," - " each of which has a name, potentially a list of arguments," - " and a return type.", - fields=lambda: { - "name": GraphQLField( - GraphQLNonNull(GraphQLString), resolve=FieldResolvers.name - ), - "description": GraphQLField(GraphQLString, resolve=FieldResolvers.description), - "args": GraphQLField( - GraphQLNonNull(GraphQLList(GraphQLNonNull(__InputValue))), - args={ - "includeDeprecated": GraphQLArgument( - GraphQLBoolean, default_value=False - ) - }, - resolve=FieldResolvers.args, - ), - "type": GraphQLField(GraphQLNonNull(__Type), resolve=FieldResolvers.type), - "isDeprecated": GraphQLField( - GraphQLNonNull(GraphQLBoolean), - resolve=FieldResolvers.is_deprecated, - ), - "deprecationReason": GraphQLField( 
- GraphQLString, resolve=FieldResolvers.deprecation_reason - ), - }, +_Type: GraphQLObjectType = GraphQLObjectType( + name="__Type", + description="The fundamental unit of any GraphQL Schema is the type." + " There are many kinds of types in GraphQL as represented" + " by the `__TypeKind` enum.\n\nDepending on the kind of a" + " type, certain fields describe information about that type." + " Scalar types provide no information beyond a name, description" + " and optional `specifiedByURL`, while Enum types provide their values." + " Object and Interface types provide the fields they describe." + " Abstract types, Union and Interface, provide the Object" + " types possible at runtime. List and NonNull types compose" + " other types.", + fields=TypeFields, ) -class FieldResolvers: +class FieldFields(GraphQLFieldMap): + def __new__(cls): + return { + "name": GraphQLField(GraphQLNonNull(GraphQLString), resolve=cls.name), + "description": GraphQLField(GraphQLString, resolve=cls.description), + "args": GraphQLField( + GraphQLNonNull(GraphQLList(GraphQLNonNull(_InputValue))), + args={ + "includeDeprecated": GraphQLArgument( + GraphQLBoolean, default_value=False + ) + }, + resolve=cls.args, + ), + "type": GraphQLField(GraphQLNonNull(_Type), resolve=cls.type), + "isDeprecated": GraphQLField( + GraphQLNonNull(GraphQLBoolean), + resolve=cls.is_deprecated, + ), + "deprecationReason": GraphQLField( + GraphQLString, resolve=cls.deprecation_reason + ), + } + @staticmethod def name(item, _info): return item[0] @@ -433,39 +457,38 @@ def deprecation_reason(item, _info): return item[1].deprecation_reason -__InputValue: GraphQLObjectType = GraphQLObjectType( - name="__InputValue", - description="Arguments provided to Fields or Directives and the input" - " fields of an InputObject are represented as Input Values" - " which describe their type and optionally a default value.", - fields=lambda: { - "name": GraphQLField( - GraphQLNonNull(GraphQLString), resolve=InputValueFieldResolvers.name - ), - "description": GraphQLField( - GraphQLString, resolve=InputValueFieldResolvers.description - ), - "type": GraphQLField( - GraphQLNonNull(__Type), resolve=InputValueFieldResolvers.type - ), - "defaultValue": GraphQLField( - GraphQLString, - description="A GraphQL-formatted string representing" - " the default value for this input value.", - resolve=InputValueFieldResolvers.default_value, - ), - "isDeprecated": GraphQLField( - GraphQLNonNull(GraphQLBoolean), - resolve=InputValueFieldResolvers.is_deprecated, - ), - "deprecationReason": GraphQLField( - GraphQLString, resolve=InputValueFieldResolvers.deprecation_reason - ), - }, +_Field: GraphQLObjectType = GraphQLObjectType( + name="__Field", + description="Object and Interface types are described by a list of Fields," + " each of which has a name, potentially a list of arguments," + " and a return type.", + fields=FieldFields, ) -class InputValueFieldResolvers: +class InputValueFields(GraphQLFieldMap): + def __new__(cls): + return { + "name": GraphQLField(GraphQLNonNull(GraphQLString), resolve=cls.name), + "description": GraphQLField( + GraphQLString, resolve=InputValueFields.description + ), + "type": GraphQLField(GraphQLNonNull(_Type), resolve=cls.type), + "defaultValue": GraphQLField( + GraphQLString, + description="A GraphQL-formatted string representing" + " the default value for this input value.", + resolve=cls.default_value, + ), + "isDeprecated": GraphQLField( + GraphQLNonNull(GraphQLBoolean), + resolve=cls.is_deprecated, + ), + "deprecationReason": GraphQLField( 
+ GraphQLString, resolve=cls.deprecation_reason + ), + } + @staticmethod def name(item, _info): return item[0] @@ -495,27 +518,57 @@ def deprecation_reason(item, _info): return item[1].deprecation_reason -__EnumValue: GraphQLObjectType = GraphQLObjectType( +_InputValue: GraphQLObjectType = GraphQLObjectType( + name="__InputValue", + description="Arguments provided to Fields or Directives and the input" + " fields of an InputObject are represented as Input Values" + " which describe their type and optionally a default value.", + fields=InputValueFields, +) + + +class EnumValueFields(GraphQLFieldMap): + def __new__(cls): + return { + "name": GraphQLField( + GraphQLNonNull(GraphQLString), resolve=EnumValueFields.name + ), + "description": GraphQLField( + GraphQLString, resolve=EnumValueFields.description + ), + "isDeprecated": GraphQLField( + GraphQLNonNull(GraphQLBoolean), + resolve=EnumValueFields.is_deprecated, + ), + "deprecationReason": GraphQLField( + GraphQLString, resolve=EnumValueFields.deprecation_reason + ), + } + + @staticmethod + def name(item, _info): + return item[0] + + @staticmethod + def description(item, _info): + return item[1].description + + @staticmethod + def is_deprecated(item, _info): + return item[1].deprecation_reason is not None + + @staticmethod + def deprecation_reason(item, _info): + return item[1].deprecation_reason + + +_EnumValue: GraphQLObjectType = GraphQLObjectType( name="__EnumValue", description="One possible value for a given Enum. Enum values are unique" " values, not a placeholder for a string or numeric value." " However an Enum value is returned in a JSON response as a" " string.", - fields=lambda: { - "name": GraphQLField( - GraphQLNonNull(GraphQLString), resolve=lambda item, _info: item[0] - ), - "description": GraphQLField( - GraphQLString, resolve=lambda item, _info: item[1].description - ), - "isDeprecated": GraphQLField( - GraphQLNonNull(GraphQLBoolean), - resolve=lambda item, _info: item[1].deprecation_reason is not None, - ), - "deprecationReason": GraphQLField( - GraphQLString, resolve=lambda item, _info: item[1].deprecation_reason - ), - }, + fields=EnumValueFields, ) @@ -530,7 +583,7 @@ class TypeKind(Enum): NON_NULL = "non-null" -__TypeKind: GraphQLEnumType = GraphQLEnumType( +_TypeKind: GraphQLEnumType = GraphQLEnumType( name="__TypeKind", description="An enum describing what kind of type a given `__Type` is.", values={ @@ -575,19 +628,33 @@ class TypeKind(Enum): ) +class MetaFields: + @staticmethod + def schema(_source, info): + return info.schema + + @staticmethod + def type(_source, info, **args): + return info.schema.get_type(args["name"]) + + @staticmethod + def type_name(_source, info, **_args): + return info.parent_type.name + + SchemaMetaFieldDef = GraphQLField( - GraphQLNonNull(__Schema), # name = '__schema' + GraphQLNonNull(_Schema), # name = '__schema' description="Access the current type schema of this server.", args={}, - resolve=lambda _source, info: info.schema, + resolve=MetaFields.schema, ) TypeMetaFieldDef = GraphQLField( - __Type, # name = '__type' + _Type, # name = '__type' description="Request the type information of a single type.", args={"name": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - resolve=lambda _source, info, **args: info.schema.get_type(args["name"]), + resolve=MetaFields.type, ) @@ -595,21 +662,21 @@ class TypeKind(Enum): GraphQLNonNull(GraphQLString), # name='__typename' description="The name of the current Object type at runtime.", args={}, - resolve=lambda _source, info, **_args: 
info.parent_type.name, + resolve=MetaFields.type_name, ) # Since double underscore names are subject to name mangling in Python, # the introspection classes are best imported via this dictionary: introspection_types: Mapping[str, GraphQLNamedType] = { # treat as read-only - "__Schema": __Schema, - "__Directive": __Directive, - "__DirectiveLocation": __DirectiveLocation, - "__Type": __Type, - "__Field": __Field, - "__InputValue": __InputValue, - "__EnumValue": __EnumValue, - "__TypeKind": __TypeKind, + "__Schema": _Schema, + "__Directive": _Directive, + "__DirectiveLocation": _DirectiveLocation, + "__Type": _Type, + "__Field": _Field, + "__InputValue": _InputValue, + "__EnumValue": _EnumValue, + "__TypeKind": _TypeKind, } """A mapping containing all introspection types with their names as keys""" @@ -617,3 +684,7 @@ class TypeKind(Enum): def is_introspection_type(type_: GraphQLNamedType) -> bool: """Check whether the given named GraphQL type is an introspection type.""" return type_.name in introspection_types + + +# register the introspection types to avoid redefinition +GraphQLNamedType.reserved_types.update(introspection_types) diff --git a/src/graphql/type/scalars.py b/src/graphql/type/scalars.py index 1271e27b..3f7263c1 100644 --- a/src/graphql/type/scalars.py +++ b/src/graphql/type/scalars.py @@ -320,3 +320,7 @@ def parse_id_literal(value_node: ValueNode, _variables: Any = None) -> str: def is_specified_scalar_type(type_: GraphQLNamedType) -> bool: """Check whether the given named GraphQL type is a specified scalar type.""" return type_.name in specified_scalar_types + + +# register the scalar types to avoid redefinition +GraphQLNamedType.reserved_types.update(specified_scalar_types) diff --git a/src/graphql/type/schema.py b/src/graphql/type/schema.py index 25e18d8b..c910fdc6 100644 --- a/src/graphql/type/schema.py +++ b/src/graphql/type/schema.py @@ -208,8 +208,8 @@ def __init__( # Provide specified directives (e.g. @include and @skip) by default self.directives = specified_directives if directives is None else directives - # To preserve order of user-provided types, we add first to add them to - # the set of "collected" types, so `collect_referenced_types` ignore them. + # To preserve order of user-provided types, we first add them to the set + # of "collected" types, so `collect_referenced_types` ignores them. if types: all_referenced_types = TypeSet.with_initial_types(types) collect_referenced_types = all_referenced_types.collect_referenced_types @@ -262,6 +262,7 @@ def __init__( "Schema must contain uniquely named types" f" but contains multiple types named '{type_name}'." 
) + type_map[type_name] = named_type if is_interface_type(named_type): diff --git a/src/graphql/utilities/build_ast_schema.py b/src/graphql/utilities/build_ast_schema.py index d2597e0e..3f4fb804 100644 --- a/src/graphql/utilities/build_ast_schema.py +++ b/src/graphql/utilities/build_ast_schema.py @@ -7,7 +7,7 @@ GraphQLSchemaKwargs, specified_directives, ) -from .extend_schema import extend_schema_impl +from .extend_schema import ExtendSchemaImpl __all__ = [ @@ -57,7 +57,9 @@ def build_ast_schema( extension_ast_nodes=(), assume_valid=False, ) - schema_kwargs = extend_schema_impl(empty_schema_kwargs, document_ast, assume_valid) + schema_kwargs = ExtendSchemaImpl.extend_schema_args( + empty_schema_kwargs, document_ast, assume_valid + ) if not schema_kwargs["ast_node"]: for type_ in schema_kwargs["types"] or (): diff --git a/src/graphql/utilities/build_client_schema.py b/src/graphql/utilities/build_client_schema.py index e1c128fb..1f6694b1 100644 --- a/src/graphql/utilities/build_client_schema.py +++ b/src/graphql/utilities/build_client_schema.py @@ -134,11 +134,15 @@ def build_type(type_: IntrospectionType) -> GraphQLNamedType: def build_scalar_def( scalar_introspection: IntrospectionScalarType, ) -> GraphQLScalarType: - return GraphQLScalarType( - name=scalar_introspection["name"], - description=scalar_introspection.get("description"), - specified_by_url=scalar_introspection.get("specifiedByURL"), - ) + name = scalar_introspection["name"] + try: + return cast(GraphQLScalarType, GraphQLScalarType.reserved_types[name]) + except KeyError: + return GraphQLScalarType( + name=name, + description=scalar_introspection.get("description"), + specified_by_url=scalar_introspection.get("specifiedByURL"), + ) def build_implementations_list( implementing_introspection: Union[ @@ -161,12 +165,16 @@ def build_implementations_list( def build_object_def( object_introspection: IntrospectionObjectType, ) -> GraphQLObjectType: - return GraphQLObjectType( - name=object_introspection["name"], - description=object_introspection.get("description"), - interfaces=lambda: build_implementations_list(object_introspection), - fields=lambda: build_field_def_map(object_introspection), - ) + name = object_introspection["name"] + try: + return cast(GraphQLObjectType, GraphQLObjectType.reserved_types[name]) + except KeyError: + return GraphQLObjectType( + name=name, + description=object_introspection.get("description"), + interfaces=lambda: build_implementations_list(object_introspection), + fields=lambda: build_field_def_map(object_introspection), + ) def build_interface_def( interface_introspection: IntrospectionInterfaceType, @@ -200,18 +208,22 @@ def build_enum_def(enum_introspection: IntrospectionEnumType) -> GraphQLEnumType "Introspection result missing enumValues:" f" {inspect(enum_introspection)}." 
) - return GraphQLEnumType( - name=enum_introspection["name"], - description=enum_introspection.get("description"), - values={ - value_introspect["name"]: GraphQLEnumValue( - value=value_introspect["name"], - description=value_introspect.get("description"), - deprecation_reason=value_introspect.get("deprecationReason"), - ) - for value_introspect in enum_introspection["enumValues"] - }, - ) + name = enum_introspection["name"] + try: + return cast(GraphQLEnumType, GraphQLEnumType.reserved_types[name]) + except KeyError: + return GraphQLEnumType( + name=name, + description=enum_introspection.get("description"), + values={ + value_introspect["name"]: GraphQLEnumValue( + value=value_introspect["name"], + description=value_introspect.get("description"), + deprecation_reason=value_introspect.get("deprecationReason"), + ) + for value_introspect in enum_introspection["enumValues"] + }, + ) def build_input_object_def( input_object_introspection: IntrospectionInputObjectType, diff --git a/src/graphql/utilities/extend_schema.py b/src/graphql/utilities/extend_schema.py index 37b9bd98..950b8740 100644 --- a/src/graphql/utilities/extend_schema.py +++ b/src/graphql/utilities/extend_schema.py @@ -1,13 +1,14 @@ from collections import defaultdict +from functools import partial from typing import ( Any, - Callable, Collection, DefaultDict, Dict, List, Mapping, Optional, + Tuple, Union, cast, ) @@ -88,7 +89,7 @@ __all__ = [ "extend_schema", - "extend_schema_impl", + "ExtendSchemaImpl", ] @@ -126,153 +127,221 @@ def extend_schema( assert_valid_sdl_extension(document_ast, schema) schema_kwargs = schema.to_kwargs() - extended_kwargs = extend_schema_impl(schema_kwargs, document_ast, assume_valid) + extended_kwargs = ExtendSchemaImpl.extend_schema_args( + schema_kwargs, document_ast, assume_valid + ) return ( schema if schema_kwargs is extended_kwargs else GraphQLSchema(**extended_kwargs) ) -def extend_schema_impl( - schema_kwargs: GraphQLSchemaKwargs, - document_ast: DocumentNode, - assume_valid: bool = False, -) -> GraphQLSchemaKwargs: - """Extend the given schema arguments with extensions from a given document. +class ExtendSchemaImpl: + """Helper class implementing the methods to extend a schema. + + Note: We use a class instead of an implementation with local functions + and lambda functions so that the extended schema can be pickled. For internal use only. """ - # Note: schema_kwargs should become a TypedDict once we require Python 3.8 - - # Collect the type definitions and extensions found in the document. - type_defs: List[TypeDefinitionNode] = [] - type_extensions_map: DefaultDict[str, Any] = defaultdict(list) - - # New directives and types are separate because a directives and types can have the - # same name. For example, a type named "skip". - directive_defs: List[DirectiveDefinitionNode] = [] - - schema_def: Optional[SchemaDefinitionNode] = None - # Schema extensions are collected which may add additional operation types. 
- schema_extensions: List[SchemaExtensionNode] = [] - - for def_ in document_ast.definitions: - if isinstance(def_, SchemaDefinitionNode): - schema_def = def_ - elif isinstance(def_, SchemaExtensionNode): - schema_extensions.append(def_) - elif isinstance(def_, TypeDefinitionNode): - type_defs.append(def_) - elif isinstance(def_, TypeExtensionNode): - extended_type_name = def_.name.value - type_extensions_map[extended_type_name].append(def_) - elif isinstance(def_, DirectiveDefinitionNode): - directive_defs.append(def_) - - # If this document contains no new types, extensions, or directives then return the - # same unmodified GraphQLSchema instance. - if ( - not type_extensions_map - and not type_defs - and not directive_defs - and not schema_extensions - and not schema_def - ): - return schema_kwargs - - # Below are functions used for producing this schema that have closed over this - # scope and have access to the schema, cache, and newly defined types. + + type_map: Dict[str, GraphQLNamedType] + type_extensions_map: Dict[str, Any] + + def __init__(self, type_extensions_map: Dict[str, Any]): + self.type_map = {} + self.type_extensions_map = type_extensions_map + + @classmethod + def extend_schema_args( + cls, + schema_kwargs: GraphQLSchemaKwargs, + document_ast: DocumentNode, + assume_valid: bool = False, + ) -> GraphQLSchemaKwargs: + """Extend the given schema arguments with extensions from a given document. + + For internal use only. + """ + # Note: schema_kwargs should become a TypedDict once we require Python 3.8 + + # Collect the type definitions and extensions found in the document. + type_defs: List[TypeDefinitionNode] = [] + type_extensions_map: DefaultDict[str, Any] = defaultdict(list) + + # New directives and types are separate because a directives and types can have + # the same name. For example, a type named "skip". + directive_defs: List[DirectiveDefinitionNode] = [] + + schema_def: Optional[SchemaDefinitionNode] = None + # Schema extensions are collected which may add additional operation types. + schema_extensions: List[SchemaExtensionNode] = [] + + for def_ in document_ast.definitions: + if isinstance(def_, SchemaDefinitionNode): + schema_def = def_ + elif isinstance(def_, SchemaExtensionNode): + schema_extensions.append(def_) + elif isinstance(def_, TypeDefinitionNode): + type_defs.append(def_) + elif isinstance(def_, TypeExtensionNode): + extended_type_name = def_.name.value + type_extensions_map[extended_type_name].append(def_) + elif isinstance(def_, DirectiveDefinitionNode): + directive_defs.append(def_) + + # If this document contains no new types, extensions, or directives then return + # the same unmodified GraphQLSchema instance. + if ( + not type_extensions_map + and not type_defs + and not directive_defs + and not schema_extensions + and not schema_def + ): + return schema_kwargs + + self = cls(type_extensions_map) + for existing_type in schema_kwargs["types"] or (): + self.type_map[existing_type.name] = self.extend_named_type(existing_type) + for type_node in type_defs: + name = type_node.name.value + self.type_map[name] = std_type_map.get(name) or self.build_type(type_node) + + # Get the extended root operation types. + operation_types: Dict[OperationType, GraphQLNamedType] = {} + for operation_type in OperationType: + original_type = schema_kwargs[operation_type.value] + if original_type: + operation_types[operation_type] = self.replace_named_type(original_type) + # Then, incorporate schema definition and all schema extensions. 
+ if schema_def: + operation_types.update(self.get_operation_types([schema_def])) + if schema_extensions: + operation_types.update(self.get_operation_types(schema_extensions)) + + # Then produce and return the kwargs for a Schema with these types. + get_operation = operation_types.get + return GraphQLSchemaKwargs( + query=get_operation(OperationType.QUERY), # type: ignore + mutation=get_operation(OperationType.MUTATION), # type: ignore + subscription=get_operation(OperationType.SUBSCRIPTION), # type: ignore + types=tuple(self.type_map.values()), + directives=tuple( + self.replace_directive(directive) + for directive in schema_kwargs["directives"] + ) + + tuple(self.build_directive(directive) for directive in directive_defs), + description=schema_def.description.value + if schema_def and schema_def.description + else None, + extensions={}, + ast_node=schema_def or schema_kwargs["ast_node"], + extension_ast_nodes=schema_kwargs["extension_ast_nodes"] + + tuple(schema_extensions), + assume_valid=assume_valid, + ) # noinspection PyTypeChecker,PyUnresolvedReferences - def replace_type(type_: GraphQLType) -> GraphQLType: + def replace_type(self, type_: GraphQLType) -> GraphQLType: if is_list_type(type_): - return GraphQLList(replace_type(type_.of_type)) # type: ignore + return GraphQLList(self.replace_type(type_.of_type)) # type: ignore if is_non_null_type(type_): - return GraphQLNonNull(replace_type(type_.of_type)) # type: ignore - return replace_named_type(type_) # type: ignore + return GraphQLNonNull(self.replace_type(type_.of_type)) # type: ignore + return self.replace_named_type(type_) # type: ignore - def replace_named_type(type_: GraphQLNamedType) -> GraphQLNamedType: + def replace_named_type(self, type_: GraphQLNamedType) -> GraphQLNamedType: # Note: While this could make early assertions to get the correctly # typed values below, that would throw immediately while type system # validation with validate_schema() will produce more actionable results. - return type_map[type_.name] + return self.type_map[type_.name] # noinspection PyShadowingNames - def replace_directive(directive: GraphQLDirective) -> GraphQLDirective: + def replace_directive(self, directive: GraphQLDirective) -> GraphQLDirective: kwargs = directive.to_kwargs() return GraphQLDirective( **merge_kwargs( kwargs, - args={name: extend_arg(arg) for name, arg in kwargs["args"].items()}, + args={ + name: self.extend_arg(arg) for name, arg in kwargs["args"].items() + }, ) ) - def extend_named_type(type_: GraphQLNamedType) -> GraphQLNamedType: + def extend_named_type(self, type_: GraphQLNamedType) -> GraphQLNamedType: if is_introspection_type(type_) or is_specified_scalar_type(type_): # Builtin types are not extended. 
return type_ if is_scalar_type(type_): type_ = cast(GraphQLScalarType, type_) - return extend_scalar_type(type_) + return self.extend_scalar_type(type_) if is_object_type(type_): type_ = cast(GraphQLObjectType, type_) - return extend_object_type(type_) + return self.extend_object_type(type_) if is_interface_type(type_): type_ = cast(GraphQLInterfaceType, type_) - return extend_interface_type(type_) + return self.extend_interface_type(type_) if is_union_type(type_): type_ = cast(GraphQLUnionType, type_) - return extend_union_type(type_) + return self.extend_union_type(type_) if is_enum_type(type_): type_ = cast(GraphQLEnumType, type_) - return extend_enum_type(type_) + return self.extend_enum_type(type_) if is_input_object_type(type_): type_ = cast(GraphQLInputObjectType, type_) - return extend_input_object_type(type_) + return self.extend_input_object_type(type_) # Not reachable. All possible types have been considered. raise TypeError(f"Unexpected type: {inspect(type_)}.") # pragma: no cover + def extend_input_object_type_fields( + self, kwargs: Dict[str, Any], extensions: Tuple[Any, ...] + ) -> GraphQLInputFieldMap: + return { + **{ + name: GraphQLInputField( + **merge_kwargs( + field.to_kwargs(), + type_=self.replace_type(field.type), + ) + ) + for name, field in kwargs["fields"].items() + }, + **self.build_input_field_map(extensions), + } + # noinspection PyShadowingNames def extend_input_object_type( + self, type_: GraphQLInputObjectType, ) -> GraphQLInputObjectType: kwargs = type_.to_kwargs() - extensions = tuple(type_extensions_map[kwargs["name"]]) + extensions = tuple(self.type_extensions_map[kwargs["name"]]) return GraphQLInputObjectType( **merge_kwargs( kwargs, - fields=lambda: { - **{ - name: GraphQLInputField( - **merge_kwargs( - field.to_kwargs(), - type_=replace_type(field.type), - ) - ) - for name, field in kwargs["fields"].items() - }, - **build_input_field_map(extensions), - }, + fields=partial( + self.extend_input_object_type_fields, kwargs, extensions + ), extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions, ) ) - def extend_enum_type(type_: GraphQLEnumType) -> GraphQLEnumType: + def extend_enum_type(self, type_: GraphQLEnumType) -> GraphQLEnumType: kwargs = type_.to_kwargs() - extensions = tuple(type_extensions_map[kwargs["name"]]) + extensions = tuple(self.type_extensions_map[kwargs["name"]]) return GraphQLEnumType( **merge_kwargs( kwargs, - values={**kwargs["values"], **build_enum_value_map(extensions)}, + values={**kwargs["values"], **self.build_enum_value_map(extensions)}, extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions, ) ) - def extend_scalar_type(type_: GraphQLScalarType) -> GraphQLScalarType: + def extend_scalar_type(self, type_: GraphQLScalarType) -> GraphQLScalarType: kwargs = type_.to_kwargs() - extensions = tuple(type_extensions_map[kwargs["name"]]) + extensions = tuple(self.type_extensions_map[kwargs["name"]]) specified_by_url = kwargs["specified_by_url"] for extension_node in extensions: @@ -286,120 +355,148 @@ def extend_scalar_type(type_: GraphQLScalarType) -> GraphQLScalarType: ) ) + def extend_object_type_interfaces( + self, kwargs: Dict[str, Any], extensions: Tuple[Any, ...] + ) -> List[GraphQLInterfaceType]: + return [ + cast(GraphQLInterfaceType, self.replace_named_type(interface)) + for interface in kwargs["interfaces"] + ] + self.build_interfaces(extensions) + + def extend_object_type_fields( + self, kwargs: Dict[str, Any], extensions: Tuple[Any, ...] 
+ ) -> GraphQLFieldMap: + return { + **{ + name: self.extend_field(field) + for name, field in kwargs["fields"].items() + }, + **self.build_field_map(extensions), + } + # noinspection PyShadowingNames - def extend_object_type(type_: GraphQLObjectType) -> GraphQLObjectType: + def extend_object_type(self, type_: GraphQLObjectType) -> GraphQLObjectType: kwargs = type_.to_kwargs() - extensions = tuple(type_extensions_map[kwargs["name"]]) + extensions = tuple(self.type_extensions_map[kwargs["name"]]) return GraphQLObjectType( **merge_kwargs( kwargs, - interfaces=lambda: [ - cast(GraphQLInterfaceType, replace_named_type(interface)) - for interface in kwargs["interfaces"] - ] - + build_interfaces(extensions), - fields=lambda: { - **{ - name: extend_field(field) - for name, field in kwargs["fields"].items() - }, - **build_field_map(extensions), - }, + interfaces=partial( + self.extend_object_type_interfaces, kwargs, extensions + ), + fields=partial(self.extend_object_type_fields, kwargs, extensions), extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions, ) ) + def extend_interface_type_interfaces( + self, kwargs: Dict[str, Any], extensions: Tuple[Any, ...] + ) -> List[GraphQLInterfaceType]: + return [ + cast(GraphQLInterfaceType, self.replace_named_type(interface)) + for interface in kwargs["interfaces"] + ] + self.build_interfaces(extensions) + + def extend_interface_type_fields( + self, kwargs: Dict[str, Any], extensions: Tuple[Any, ...] + ) -> GraphQLFieldMap: + return { + **{ + name: self.extend_field(field) + for name, field in kwargs["fields"].items() + }, + **self.build_field_map(extensions), + } + # noinspection PyShadowingNames - def extend_interface_type(type_: GraphQLInterfaceType) -> GraphQLInterfaceType: + def extend_interface_type( + self, type_: GraphQLInterfaceType + ) -> GraphQLInterfaceType: kwargs = type_.to_kwargs() - extensions = tuple(type_extensions_map[kwargs["name"]]) + extensions = tuple(self.type_extensions_map[kwargs["name"]]) return GraphQLInterfaceType( **merge_kwargs( kwargs, - interfaces=lambda: [ - cast(GraphQLInterfaceType, replace_named_type(interface)) - for interface in kwargs["interfaces"] - ] - + build_interfaces(extensions), - fields=lambda: { - **{ - name: extend_field(field) - for name, field in kwargs["fields"].items() - }, - **build_field_map(extensions), - }, + interfaces=partial( + self.extend_interface_type_interfaces, kwargs, extensions + ), + fields=partial(self.extend_interface_type_fields, kwargs, extensions), extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions, ) ) - def extend_union_type(type_: GraphQLUnionType) -> GraphQLUnionType: + def extend_union_type_types( + self, kwargs: Dict[str, Any], extensions: Tuple[Any, ...] 
+ ) -> List[GraphQLObjectType]: + return [ + cast(GraphQLObjectType, self.replace_named_type(member_type)) + for member_type in kwargs["types"] + ] + self.build_union_types(extensions) + + def extend_union_type(self, type_: GraphQLUnionType) -> GraphQLUnionType: kwargs = type_.to_kwargs() - extensions = tuple(type_extensions_map[kwargs["name"]]) + extensions = tuple(self.type_extensions_map[kwargs["name"]]) return GraphQLUnionType( **merge_kwargs( kwargs, - types=lambda: [ - cast(GraphQLObjectType, replace_named_type(member_type)) - for member_type in kwargs["types"] - ] - + build_union_types(extensions), + types=partial(self.extend_union_type_types, kwargs, extensions), extension_ast_nodes=kwargs["extension_ast_nodes"] + extensions, - ) + ), ) # noinspection PyShadowingNames - def extend_field(field: GraphQLField) -> GraphQLField: + def extend_field(self, field: GraphQLField) -> GraphQLField: return GraphQLField( **merge_kwargs( field.to_kwargs(), - type_=replace_type(field.type), - args={name: extend_arg(arg) for name, arg in field.args.items()}, + type_=self.replace_type(field.type), + args={name: self.extend_arg(arg) for name, arg in field.args.items()}, ) ) - def extend_arg(arg: GraphQLArgument) -> GraphQLArgument: + def extend_arg(self, arg: GraphQLArgument) -> GraphQLArgument: return GraphQLArgument( **merge_kwargs( arg.to_kwargs(), - type_=replace_type(arg.type), + type_=self.replace_type(arg.type), ) ) # noinspection PyShadowingNames def get_operation_types( - nodes: Collection[Union[SchemaDefinitionNode, SchemaExtensionNode]] + self, nodes: Collection[Union[SchemaDefinitionNode, SchemaExtensionNode]] ) -> Dict[OperationType, GraphQLNamedType]: # Note: While this could make early assertions to get the correctly # typed values below, that would throw immediately while type system # validation with validate_schema() will produce more actionable results. 
return { - operation_type.operation: get_named_type(operation_type.type) + operation_type.operation: self.get_named_type(operation_type.type) for node in nodes for operation_type in node.operation_types or [] } # noinspection PyShadowingNames - def get_named_type(node: NamedTypeNode) -> GraphQLNamedType: + def get_named_type(self, node: NamedTypeNode) -> GraphQLNamedType: name = node.name.value - type_ = std_type_map.get(name) or type_map.get(name) + type_ = std_type_map.get(name) or self.type_map.get(name) if not type_: raise TypeError(f"Unknown type: '{name}'.") return type_ - def get_wrapped_type(node: TypeNode) -> GraphQLType: + def get_wrapped_type(self, node: TypeNode) -> GraphQLType: if isinstance(node, ListTypeNode): - return GraphQLList(get_wrapped_type(node.type)) + return GraphQLList(self.get_wrapped_type(node.type)) if isinstance(node, NonNullTypeNode): return GraphQLNonNull( - cast(GraphQLNullableType, get_wrapped_type(node.type)) + cast(GraphQLNullableType, self.get_wrapped_type(node.type)) ) - return get_named_type(cast(NamedTypeNode, node)) + return self.get_named_type(cast(NamedTypeNode, node)) - def build_directive(node: DirectiveDefinitionNode) -> GraphQLDirective: + def build_directive(self, node: DirectiveDefinitionNode) -> GraphQLDirective: locations = [DirectiveLocation[node.value] for node in node.locations] return GraphQLDirective( @@ -407,11 +504,12 @@ def build_directive(node: DirectiveDefinitionNode) -> GraphQLDirective: description=node.description.value if node.description else None, locations=locations, is_repeatable=node.repeatable, - args=build_argument_map(node.arguments), + args=self.build_argument_map(node.arguments), ast_node=node, ) def build_field_map( + self, nodes: Collection[ Union[ InterfaceTypeDefinitionNode, @@ -428,15 +526,16 @@ def build_field_map( # value, that would throw immediately while type system validation # with validate_schema() will produce more actionable results. field_map[field.name.value] = GraphQLField( - type_=cast(GraphQLOutputType, get_wrapped_type(field.type)), + type_=cast(GraphQLOutputType, self.get_wrapped_type(field.type)), description=field.description.value if field.description else None, - args=build_argument_map(field.arguments), + args=self.build_argument_map(field.arguments), deprecation_reason=get_deprecation_reason(field), ast_node=field, ) return field_map def build_argument_map( + self, args: Optional[Collection[InputValueDefinitionNode]], ) -> GraphQLArgumentMap: arg_map: GraphQLArgumentMap = {} @@ -444,7 +543,7 @@ def build_argument_map( # Note: While this could make assertions to get the correctly typed # value, that would throw immediately while type system validation # with validate_schema() will produce more actionable results. - type_ = cast(GraphQLInputType, get_wrapped_type(arg.type)) + type_ = cast(GraphQLInputType, self.get_wrapped_type(arg.type)) arg_map[arg.name.value] = GraphQLArgument( type_=type_, description=arg.description.value if arg.description else None, @@ -455,6 +554,7 @@ def build_argument_map( return arg_map def build_input_field_map( + self, nodes: Collection[ Union[InputObjectTypeDefinitionNode, InputObjectTypeExtensionNode] ], @@ -465,7 +565,7 @@ def build_input_field_map( # Note: While this could make assertions to get the correctly typed # value, that would throw immediately while type system validation # with validate_schema() will produce more actionable results. 
- type_ = cast(GraphQLInputType, get_wrapped_type(field.type)) + type_ = cast(GraphQLInputType, self.get_wrapped_type(field.type)) input_field_map[field.name.value] = GraphQLInputField( type_=type_, description=field.description.value if field.description else None, @@ -475,6 +575,7 @@ def build_input_field_map( ) return input_field_map + @staticmethod def build_enum_value_map( nodes: Collection[Union[EnumTypeDefinitionNode, EnumTypeExtensionNode]] ) -> GraphQLEnumValueMap: @@ -494,6 +595,7 @@ def build_enum_value_map( return enum_value_map def build_interfaces( + self, nodes: Collection[ Union[ InterfaceTypeDefinitionNode, @@ -503,29 +605,32 @@ def build_interfaces( ] ], ) -> List[GraphQLInterfaceType]: - interfaces: List[GraphQLInterfaceType] = [] - for node in nodes: - for type_ in node.interfaces or []: - # Note: While this could make assertions to get the correctly typed - # value, that would throw immediately while type system validation - # with validate_schema() will produce more actionable results. - interfaces.append(cast(GraphQLInterfaceType, get_named_type(type_))) - return interfaces + # Note: While this could make assertions to get the correctly typed + # value, that would throw immediately while type system validation + # with validate_schema() will produce more actionable results. + return [ + cast(GraphQLInterfaceType, self.get_named_type(type_)) + for node in nodes + for type_ in node.interfaces or [] + ] def build_union_types( + self, nodes: Collection[Union[UnionTypeDefinitionNode, UnionTypeExtensionNode]], ) -> List[GraphQLObjectType]: - types: List[GraphQLObjectType] = [] - for node in nodes: - for type_ in node.types or []: - # Note: While this could make assertions to get the correctly typed - # value, that would throw immediately while type system validation - # with validate_schema() will produce more actionable results. - types.append(cast(GraphQLObjectType, get_named_type(type_))) - return types + # Note: While this could make assertions to get the correctly typed + # value, that would throw immediately while type system validation + # with validate_schema() will produce more actionable results. 
+ return [ + cast(GraphQLObjectType, self.get_named_type(type_)) + for node in nodes + for type_ in node.types or [] + ] - def build_object_type(ast_node: ObjectTypeDefinitionNode) -> GraphQLObjectType: - extension_nodes = type_extensions_map[ast_node.name.value] + def build_object_type( + self, ast_node: ObjectTypeDefinitionNode + ) -> GraphQLObjectType: + extension_nodes = self.type_extensions_map[ast_node.name.value] all_nodes: List[Union[ObjectTypeDefinitionNode, ObjectTypeExtensionNode]] = [ ast_node, *extension_nodes, @@ -533,30 +638,31 @@ def build_object_type(ast_node: ObjectTypeDefinitionNode) -> GraphQLObjectType: return GraphQLObjectType( name=ast_node.name.value, description=ast_node.description.value if ast_node.description else None, - interfaces=lambda: build_interfaces(all_nodes), - fields=lambda: build_field_map(all_nodes), + interfaces=partial(self.build_interfaces, all_nodes), + fields=partial(self.build_field_map, all_nodes), ast_node=ast_node, extension_ast_nodes=extension_nodes, ) def build_interface_type( + self, ast_node: InterfaceTypeDefinitionNode, ) -> GraphQLInterfaceType: - extension_nodes = type_extensions_map[ast_node.name.value] + extension_nodes = self.type_extensions_map[ast_node.name.value] all_nodes: List[ Union[InterfaceTypeDefinitionNode, InterfaceTypeExtensionNode] ] = [ast_node, *extension_nodes] return GraphQLInterfaceType( name=ast_node.name.value, description=ast_node.description.value if ast_node.description else None, - interfaces=lambda: build_interfaces(all_nodes), - fields=lambda: build_field_map(all_nodes), + interfaces=partial(self.build_interfaces, all_nodes), + fields=partial(self.build_field_map, all_nodes), ast_node=ast_node, extension_ast_nodes=extension_nodes, ) - def build_enum_type(ast_node: EnumTypeDefinitionNode) -> GraphQLEnumType: - extension_nodes = type_extensions_map[ast_node.name.value] + def build_enum_type(self, ast_node: EnumTypeDefinitionNode) -> GraphQLEnumType: + extension_nodes = self.type_extensions_map[ast_node.name.value] all_nodes: List[Union[EnumTypeDefinitionNode, EnumTypeExtensionNode]] = [ ast_node, *extension_nodes, @@ -564,13 +670,13 @@ def build_enum_type(ast_node: EnumTypeDefinitionNode) -> GraphQLEnumType: return GraphQLEnumType( name=ast_node.name.value, description=ast_node.description.value if ast_node.description else None, - values=build_enum_value_map(all_nodes), + values=self.build_enum_value_map(all_nodes), ast_node=ast_node, extension_ast_nodes=extension_nodes, ) - def build_union_type(ast_node: UnionTypeDefinitionNode) -> GraphQLUnionType: - extension_nodes = type_extensions_map[ast_node.name.value] + def build_union_type(self, ast_node: UnionTypeDefinitionNode) -> GraphQLUnionType: + extension_nodes = self.type_extensions_map[ast_node.name.value] all_nodes: List[Union[UnionTypeDefinitionNode, UnionTypeExtensionNode]] = [ ast_node, *extension_nodes, @@ -578,13 +684,15 @@ def build_union_type(ast_node: UnionTypeDefinitionNode) -> GraphQLUnionType: return GraphQLUnionType( name=ast_node.name.value, description=ast_node.description.value if ast_node.description else None, - types=lambda: build_union_types(all_nodes), + types=partial(self.build_union_types, all_nodes), ast_node=ast_node, extension_ast_nodes=extension_nodes, ) - def build_scalar_type(ast_node: ScalarTypeDefinitionNode) -> GraphQLScalarType: - extension_nodes = type_extensions_map[ast_node.name.value] + def build_scalar_type( + self, ast_node: ScalarTypeDefinitionNode + ) -> GraphQLScalarType: + extension_nodes = 
self.type_extensions_map[ast_node.name.value] return GraphQLScalarType( name=ast_node.name.value, description=ast_node.description.value if ast_node.description else None, @@ -594,83 +702,36 @@ def build_scalar_type(ast_node: ScalarTypeDefinitionNode) -> GraphQLScalarType: ) def build_input_object_type( + self, ast_node: InputObjectTypeDefinitionNode, ) -> GraphQLInputObjectType: - extension_nodes = type_extensions_map[ast_node.name.value] + extension_nodes = self.type_extensions_map[ast_node.name.value] all_nodes: List[ Union[InputObjectTypeDefinitionNode, InputObjectTypeExtensionNode] ] = [ast_node, *extension_nodes] return GraphQLInputObjectType( name=ast_node.name.value, description=ast_node.description.value if ast_node.description else None, - fields=lambda: build_input_field_map(all_nodes), + fields=partial(self.build_input_field_map, all_nodes), ast_node=ast_node, extension_ast_nodes=extension_nodes, ) - build_type_for_kind = cast( - Dict[str, Callable[[TypeDefinitionNode], GraphQLNamedType]], - { - "object_type_definition": build_object_type, - "interface_type_definition": build_interface_type, - "enum_type_definition": build_enum_type, - "union_type_definition": build_union_type, - "scalar_type_definition": build_scalar_type, - "input_object_type_definition": build_input_object_type, - }, - ) - - def build_type(ast_node: TypeDefinitionNode) -> GraphQLNamedType: + def build_type(self, ast_node: TypeDefinitionNode) -> GraphQLNamedType: + kind = ast_node.kind try: - # object_type_definition_node is built with _build_object_type etc. - build_function = build_type_for_kind[ast_node.kind] - except KeyError: # pragma: no cover + kind = kind.removesuffix("_definition") + except AttributeError: # pragma: no cover (Python < 3.9) + if kind.endswith("_definition"): + kind = kind[:-11] + try: + build = getattr(self, f"build_{kind}") + except AttributeError: # pragma: no cover # Not reachable. All possible type definition nodes have been considered. raise TypeError( # pragma: no cover f"Unexpected type definition node: {inspect(ast_node)}." ) - else: - return build_function(ast_node) - - type_map: Dict[str, GraphQLNamedType] = {} - for existing_type in schema_kwargs["types"] or (): - type_map[existing_type.name] = extend_named_type(existing_type) - for type_node in type_defs: - name = type_node.name.value - type_map[name] = std_type_map.get(name) or build_type(type_node) - - # Get the extended root operation types. - operation_types: Dict[OperationType, GraphQLNamedType] = {} - for operation_type in OperationType: - original_type = schema_kwargs[operation_type.value] - if original_type: - operation_types[operation_type] = replace_named_type(original_type) - # Then, incorporate schema definition and all schema extensions. - if schema_def: - operation_types.update(get_operation_types([schema_def])) - if schema_extensions: - operation_types.update(get_operation_types(schema_extensions)) - - # Then produce and return the kwargs for a Schema with these types. 
- get_operation = operation_types.get - return GraphQLSchemaKwargs( - query=get_operation(OperationType.QUERY), # type: ignore - mutation=get_operation(OperationType.MUTATION), # type: ignore - subscription=get_operation(OperationType.SUBSCRIPTION), # type: ignore - types=tuple(type_map.values()), - directives=tuple( - replace_directive(directive) for directive in schema_kwargs["directives"] - ) - + tuple(build_directive(directive) for directive in directive_defs), - description=schema_def.description.value - if schema_def and schema_def.description - else None, - extensions={}, - ast_node=schema_def or schema_kwargs["ast_node"], - extension_ast_nodes=schema_kwargs["extension_ast_nodes"] - + tuple(schema_extensions), - assume_valid=assume_valid, - ) + return build(ast_node) std_type_map: Mapping[str, Union[GraphQLNamedType, GraphQLObjectType]] = { diff --git a/tests/language/test_block_string_fuzz.py b/tests/language/test_block_string_fuzz.py index 8de96b22..e3a38f38 100644 --- a/tests/language/test_block_string_fuzz.py +++ b/tests/language/test_block_string_fuzz.py @@ -6,7 +6,7 @@ print_block_string, ) -from ..utils import dedent, gen_fuzz_strings +from ..utils import dedent, gen_fuzz_strings, timeout_factor def lex_value(s: str) -> str: @@ -42,7 +42,7 @@ def assert_non_printable_block_string(test_value: str) -> None: def describe_print_block_string(): @mark.slow - @mark.timeout(20) + @mark.timeout(80 * timeout_factor) def correctly_print_random_strings(): # Testing with length >7 is taking exponentially more time. However, it is # highly recommended testing with increased limit if you make any change. diff --git a/tests/language/test_schema_parser.py b/tests/language/test_schema_parser.py index 673f1554..feab6543 100644 --- a/tests/language/test_schema_parser.py +++ b/tests/language/test_schema_parser.py @@ -1,3 +1,5 @@ +import pickle +from copy import deepcopy from textwrap import dedent from typing import List, Optional, Tuple @@ -797,19 +799,36 @@ def directive_with_incorrect_locations(): def parses_kitchen_sink_schema(kitchen_sink_sdl): # noqa: F811 assert parse(kitchen_sink_sdl) - def can_pickle_and_unpickle_kitchen_sink_schema_ast(kitchen_sink_sdl): # noqa: F811 - import pickle - - # create a schema AST from the kitchen sink SDL - doc = parse(kitchen_sink_sdl) - # check that the schema AST can be pickled - # (particularly, there should be no recursion error) - dumped = pickle.dumps(doc) - # check that the pickle size is reasonable - assert len(dumped) < 50 * len(kitchen_sink_sdl) - loaded = pickle.loads(dumped) - # check that the un-pickled schema AST is still the same - assert loaded == doc - # check that pickling again creates the same result - dumped_again = pickle.dumps(doc) - assert dumped_again == dumped + def describe_deepcopy_and_pickle(): + def can_deep_copy_ast(kitchen_sink_sdl): # noqa: F811 + # create a schema AST from the kitchen sink SDL + doc = parse(kitchen_sink_sdl) + # make a deepcopy of the schema AST + copied_doc = deepcopy(doc) + # check that the copied AST is equal to the original one + assert copied_doc == doc + + def can_pickle_and_unpickle_ast(kitchen_sink_sdl): # noqa: F811 + # create a schema AST from the kitchen sink SDL + doc = parse(kitchen_sink_sdl) + # check that the schema AST can be pickled + # (particularly, there should be no recursion error) + dumped = pickle.dumps(doc) + # check that the pickle size is reasonable + assert len(dumped) < 50 * len(kitchen_sink_sdl) + loaded = pickle.loads(dumped) + # check that the un-pickled schema AST is still 
the same
+            assert loaded == doc
+            # check that pickling again creates the same result
+            dumped_again = pickle.dumps(doc)
+            assert dumped_again == dumped
+
+        def can_deep_copy_pickled_ast(kitchen_sink_sdl):  # noqa: F811
+            # create a schema AST from the kitchen sink SDL
+            doc = parse(kitchen_sink_sdl)
+            # pickle and unpickle the schema AST
+            loaded_doc = pickle.loads(pickle.dumps(doc))
+            # make a deepcopy of this
+            copied_doc = deepcopy(loaded_doc)
+            # check that the result is still equal to the original schema AST
+            assert copied_doc == doc
diff --git a/tests/pyutils/test_undefined.py b/tests/pyutils/test_undefined.py
index b7ad8cf6..9cd5303f 100644
--- a/tests/pyutils/test_undefined.py
+++ b/tests/pyutils/test_undefined.py
@@ -1,7 +1,11 @@
-from graphql.pyutils import Undefined
+import pickle
 
+from pytest import warns
 
-def describe_invalid():
+from graphql.pyutils import Undefined, UndefinedType
+
+
+def describe_Undefined():
     def has_repr():
         assert repr(Undefined) == "Undefined"
 
@@ -26,3 +30,13 @@ def only_equal_to_itself():
         false_object = False
         assert Undefined != false_object
         assert not Undefined == false_object
+
+    def cannot_be_redefined():
+        with warns(RuntimeWarning, match="Redefinition of 'Undefined'"):
+            redefined_undefined = UndefinedType()
+        assert redefined_undefined is Undefined
+
+    def can_be_pickled():
+        pickled_undefined = pickle.dumps(Undefined)
+        unpickled_undefined = pickle.loads(pickled_undefined)
+        assert unpickled_undefined is Undefined
diff --git a/tests/type/test_definition.py b/tests/type/test_definition.py
index 8515de89..24973086 100644
--- a/tests/type/test_definition.py
+++ b/tests/type/test_definition.py
@@ -1,3 +1,4 @@
+import pickle
 from enum import Enum
 from math import isnan, nan
 from typing import Dict, cast
@@ -43,6 +44,7 @@
     GraphQLScalarType,
     GraphQLString,
     GraphQLUnionType,
+    introspection_types,
 )
 
@@ -261,6 +263,17 @@ def rejects_a_scalar_type_with_incorrect_extension_ast_nodes():
             " as a collection of ScalarTypeExtensionNode instances."
         )
 
+    def pickles_a_custom_scalar_type():
+        foo_type = GraphQLScalarType("Foo")
+        cycled_foo_type = pickle.loads(pickle.dumps(foo_type))
+        assert cycled_foo_type.name == foo_type.name
+        assert cycled_foo_type is not foo_type
+
+    def pickles_a_specified_scalar_type():
+        cycled_int_type = pickle.loads(pickle.dumps(GraphQLInt))
+        assert cycled_int_type.name == "Int"
+        assert cycled_int_type is GraphQLInt
+
 
 def describe_type_system_fields():
     def defines_a_field():
@@ -1903,3 +1916,11 @@ def fields_have_repr():
             repr(GraphQLField(GraphQLList(GraphQLInt)))
             == "<GraphQLField <GraphQLList <GraphQLScalarType 'Int'>>>"
         )
+
+
+def describe_type_system_introspection_types():
+    def cannot_redefine_introspection_types():
+        for name, introspection_type in introspection_types.items():
+            assert introspection_type.name == name
+            with raises(TypeError, match=f"Redefinition of reserved type '{name}'"):
+                introspection_type.__class__(**introspection_type.to_kwargs())
diff --git a/tests/type/test_scalars.py b/tests/type/test_scalars.py
index c5413803..f2a45a67 100644
--- a/tests/type/test_scalars.py
+++ b/tests/type/test_scalars.py
@@ -1,3 +1,4 @@
+import pickle
 from math import inf, nan, pi
 from typing import Any
 
@@ -11,6 +12,7 @@
     GraphQLFloat,
     GraphQLID,
     GraphQLInt,
+    GraphQLScalarType,
     GraphQLString,
 )
 
@@ -172,6 +174,13 @@ def serializes():
             serialize([5])
         assert str(exc_info.value) == "Int cannot represent non-integer value: [5]"
 
+    def cannot_be_redefined():
+        with raises(TypeError, match="Redefinition of reserved type 'Int'"):
+            GraphQLScalarType(name="Int")
+
+    def pickles():
+        assert pickle.loads(pickle.dumps(GraphQLInt)) is GraphQLInt
+
 
 def describe_graphql_float():
     def parse_value():
         _parse_value = GraphQLFloat.parse_value
@@ -295,6 +304,13 @@ def serializes():
             str(exc_info.value) == "Float cannot represent non numeric value: [5]"
         )
 
+    def cannot_be_redefined():
+        with raises(TypeError, match="Redefinition of reserved type 'Float'"):
+            GraphQLScalarType(name="Float")
+
+    def pickles():
+        assert pickle.loads(pickle.dumps(GraphQLFloat)) is GraphQLFloat
+
 
 def describe_graphql_string():
     def parse_value():
         _parse_value = GraphQLString.parse_value
@@ -401,6 +417,13 @@ def __str__(self):
             " {'value_of': 'value_of string'}"
         )
 
+    def cannot_be_redefined():
+        with raises(TypeError, match="Redefinition of reserved type 'String'"):
+            GraphQLScalarType(name="String")
+
+    def pickles():
+        assert pickle.loads(pickle.dumps(GraphQLString)) is GraphQLString
+
 
 def describe_graphql_boolean():
     def parse_value():
         _parse_value = GraphQLBoolean.parse_value
@@ -543,6 +566,13 @@ def serializes():
             "Boolean cannot represent a non boolean value: {}"
         )
 
+    def cannot_be_redefined():
+        with raises(TypeError, match="Redefinition of reserved type 'Boolean'"):
+            GraphQLScalarType(name="Boolean")
+
+    def pickles():
+        assert pickle.loads(pickle.dumps(GraphQLBoolean)) is GraphQLBoolean
+
 
 def describe_graphql_id():
     def parse_value():
         _parse_value = GraphQLID.parse_value
@@ -663,3 +693,10 @@ def __str__(self):
         with raises(GraphQLError) as exc_info:
             serialize(["abc"])
         assert str(exc_info.value) == "ID cannot represent value: ['abc']"
+
+    def cannot_be_redefined():
+        with raises(TypeError, match="Redefinition of reserved type 'ID'"):
+            GraphQLScalarType(name="ID")
+
+    def pickles():
+        assert pickle.loads(pickle.dumps(GraphQLID)) is GraphQLID
diff --git a/tests/type/test_schema.py b/tests/type/test_schema.py
index 8dfc2c48..efd44f86 100644
--- a/tests/type/test_schema.py
+++ b/tests/type/test_schema.py
@@ -20,6 +20,7 @@
     GraphQLInt,
     GraphQLInterfaceType,
     GraphQLList,
+    GraphQLNamedType,
GraphQLObjectType, GraphQLScalarType, GraphQLSchema, @@ -331,7 +332,15 @@ def check_that_query_mutation_and_subscription_are_graphql_types(): def describe_a_schema_must_contain_uniquely_named_types(): def rejects_a_schema_which_redefines_a_built_in_type(): - FakeString = GraphQLScalarType("String") + # temporarily allow redefinition of the String scalar type + reserved_types = GraphQLNamedType.reserved_types + GraphQLScalarType.reserved_types = {} + try: + # create a redefined String scalar type + FakeString = GraphQLScalarType("String") + finally: + # protect from redefinition again + GraphQLScalarType.reserved_types = reserved_types QueryType = GraphQLObjectType( "Query", diff --git a/tests/utilities/test_build_ast_schema.py b/tests/utilities/test_build_ast_schema.py index 1ae7ffb5..bb0dc561 100644 --- a/tests/utilities/test_build_ast_schema.py +++ b/tests/utilities/test_build_ast_schema.py @@ -1,4 +1,7 @@ +import pickle +import sys from collections import namedtuple +from copy import deepcopy from typing import Union from pytest import mark, raises @@ -35,7 +38,8 @@ from graphql.utilities import build_ast_schema, build_schema, print_schema, print_type from ..fixtures import big_schema_sdl # noqa: F401 -from ..utils import dedent +from ..star_wars_schema import star_wars_schema +from ..utils import dedent, timeout_factor def cycle_sdl(sdl: str) -> str: @@ -1186,28 +1190,110 @@ def rejects_invalid_ast(): build_ast_schema({}) # type: ignore assert str(exc_info.value) == "Must provide valid Document AST." - # This currently does not work because of how extend_schema is implemented - @mark.skip(reason="pickling of schemas is not yet supported") - def can_pickle_and_unpickle_big_schema( - big_schema_sdl, # noqa: F811 - ): # pragma: no cover - import pickle - - # create a schema from the kitchen sink SDL - schema = build_schema(big_schema_sdl, assume_valid_sdl=True) - # check that the schema can be pickled - # (particularly, there should be no recursion error, - # or errors because of trying to pickle lambdas or local functions) - dumped = pickle.dumps(schema) - # check that the pickle size is reasonable - assert len(dumped) < 50 * len(big_schema_sdl) - loaded = pickle.loads(dumped) - - # check that the un-pickled schema is still the same - assert loaded == schema - # check that pickling again creates the same result - dumped_again = pickle.dumps(schema) - assert dumped_again == dumped - - # check that printing the unpickled schema gives the same SDL - assert cycle_sdl(print_schema(schema)) == cycle_sdl(big_schema_sdl) + def describe_deepcopy_and_pickle(): # pragma: no cover + sdl = print_schema(star_wars_schema) + + def can_deep_copy_schema(): + schema = build_schema(sdl, assume_valid_sdl=True) + # create a deepcopy of the schema + copied = deepcopy(schema) + # check that printing the copied schema gives the same SDL + assert print_schema(copied) == sdl + + def can_pickle_and_unpickle_star_wars_schema(): + # create a schema from the star wars SDL + schema = build_schema(sdl, assume_valid_sdl=True) + # check that the schema can be pickled + # (particularly, there should be no recursion error, + # or errors because of trying to pickle lambdas or local functions) + dumped = pickle.dumps(schema) + + # check that the pickle size is reasonable + assert len(dumped) < 25 * len(sdl) + loaded = pickle.loads(dumped) + + # check that printing the unpickled schema gives the same SDL + assert print_schema(loaded) == sdl + + # check that pickling again creates the same result + dumped = pickle.dumps(schema) 
+ assert len(dumped) < 25 * len(sdl) + loaded = pickle.loads(dumped) + assert print_schema(loaded) == sdl + + def can_deep_copy_pickled_schema(): + # create a schema from the star wars SDL + schema = build_schema(sdl, assume_valid_sdl=True) + # pickle and unpickle the schema + loaded = pickle.loads(pickle.dumps(schema)) + # create a deepcopy of the unpickled schema + copied = deepcopy(loaded) + # check that printing the copied schema gives the same SDL + assert print_schema(copied) == sdl + + @mark.slow + def describe_deepcopy_and_pickle_big(): # pragma: no cover + @mark.timeout(20 * timeout_factor) + def can_deep_copy_big_schema(big_schema_sdl): # noqa: F811 + # use our printing conventions + big_schema_sdl = cycle_sdl(big_schema_sdl) + + # create a schema from the big SDL + schema = build_schema(big_schema_sdl, assume_valid_sdl=True) + # create a deepcopy of the schema + copied = deepcopy(schema) + # check that printing the copied schema gives the same SDL + assert print_schema(copied) == big_schema_sdl + + @mark.timeout(60 * timeout_factor) + def can_pickle_and_unpickle_big_schema(big_schema_sdl): # noqa: F811 + # use our printing conventions + big_schema_sdl = cycle_sdl(big_schema_sdl) + + limit = sys.getrecursionlimit() + sys.setrecursionlimit(max(limit, 4000)) # needed for pickle + + try: + # create a schema from the big SDL + schema = build_schema(big_schema_sdl, assume_valid_sdl=True) + # check that the schema can be pickled + # (particularly, there should be no recursion error, + # or errors because of trying to pickle lambdas or local functions) + dumped = pickle.dumps(schema) + + # check that the pickle size is reasonable + assert len(dumped) < 25 * len(big_schema_sdl) + loaded = pickle.loads(dumped) + + # check that printing the unpickled schema gives the same SDL + assert print_schema(loaded) == big_schema_sdl + + # check that pickling again creates the same result + dumped = pickle.dumps(schema) + assert len(dumped) < 25 * len(big_schema_sdl) + loaded = pickle.loads(dumped) + assert print_schema(loaded) == big_schema_sdl + + finally: + sys.setrecursionlimit(limit) + + @mark.timeout(60 * timeout_factor) + def can_deep_copy_pickled_big_schema(big_schema_sdl): # noqa: F811 + # use our printing conventions + big_schema_sdl = cycle_sdl(big_schema_sdl) + + limit = sys.getrecursionlimit() + sys.setrecursionlimit(max(limit, 4000)) # needed for pickle + + try: + # create a schema from the big SDL + schema = build_schema(big_schema_sdl, assume_valid_sdl=True) + # pickle and unpickle the schema + loaded = pickle.loads(pickle.dumps(schema)) + # create a deepcopy of the unpickled schema + copied = deepcopy(loaded) + # check that printing the copied schema gives the same SDL + assert print_schema(copied) == big_schema_sdl + + finally: + sys.setrecursionlimit(limit) diff --git a/tests/utilities/test_introspection_from_schema.py b/tests/utilities/test_introspection_from_schema.py index 96ec968f..878ac0fb 100644 --- a/tests/utilities/test_introspection_from_schema.py +++ b/tests/utilities/test_introspection_from_schema.py @@ -1,12 +1,20 @@ +import pickle +import sys +from copy import deepcopy + +from pytest import mark + from graphql.type import GraphQLField, GraphQLObjectType, GraphQLSchema, GraphQLString from graphql.utilities import ( IntrospectionQuery, build_client_schema, + build_schema, introspection_from_schema, print_schema, ) -from ..utils import dedent +from ..fixtures import big_schema_introspection_result, big_schema_sdl # noqa: F401 +from ..utils import dedent, timeout_factor 
def introspection_to_sdl(introspection: IntrospectionQuery) -> str: @@ -60,3 +68,109 @@ def converts_a_simple_schema_without_description(): } """ ) + + def describe_deepcopy_and_pickle(): # pragma: no cover + # introspect the schema + introspected_schema = introspection_from_schema(schema) + introspection_size = len(str(introspected_schema)) + + def can_deep_copy_schema(): + # create a deepcopy of the schema + copied = deepcopy(schema) + # check that introspecting the copied schema gives the same result + assert introspection_from_schema(copied) == introspected_schema + + def can_pickle_and_unpickle_schema(): + # check that the schema can be pickled + # (particularly, there should be no recursion error, + # or errors because of trying to pickle lambdas or local functions) + dumped = pickle.dumps(schema) + + # check that the pickle size is reasonable + assert len(dumped) < 5 * introspection_size + loaded = pickle.loads(dumped) + + # check that introspecting the unpickled schema gives the same result + assert introspection_from_schema(loaded) == introspected_schema + + # check that pickling again creates the same result + dumped = pickle.dumps(schema) + assert len(dumped) < 5 * introspection_size + loaded = pickle.loads(dumped) + assert introspection_from_schema(loaded) == introspected_schema + + def can_deep_copy_pickled_schema(): + # pickle and unpickle the schema + loaded = pickle.loads(pickle.dumps(schema)) + # create a deepcopy of the unpickled schema + copied = deepcopy(loaded) + # check that introspecting the copied schema gives the same result + assert introspection_from_schema(copied) == introspected_schema + + @mark.slow + def describe_deepcopy_and_pickle_big(): # pragma: no cover + @mark.timeout(20 * timeout_factor) + def can_deep_copy_big_schema(big_schema_sdl): # noqa: F811 + # introspect the original big schema + big_schema = build_schema(big_schema_sdl) + expected_introspection = introspection_from_schema(big_schema) + + # create a deepcopy of the schema + copied = deepcopy(big_schema) + # check that introspecting the copied schema gives the same result + assert introspection_from_schema(copied) == expected_introspection + + @mark.timeout(60 * timeout_factor) + def can_pickle_and_unpickle_big_schema(big_schema_sdl): # noqa: F811 + # introspect the original big schema + big_schema = build_schema(big_schema_sdl) + expected_introspection = introspection_from_schema(big_schema) + size_introspection = len(str(expected_introspection)) + + limit = sys.getrecursionlimit() + sys.setrecursionlimit(max(limit, 4000)) # needed for pickle + + try: + # check that the schema can be pickled + # (particularly, there should be no recursion error, + # or errors because of trying to pickle lambdas or local functions) + dumped = pickle.dumps(big_schema) + + # check that the pickle size is reasonable + assert len(dumped) < 5 * size_introspection + loaded = pickle.loads(dumped) + + # check that introspecting the pickled schema gives the same result + assert introspection_from_schema(loaded) == expected_introspection + + # check that pickling again creates the same result + dumped = pickle.dumps(loaded) + assert len(dumped) < 5 * size_introspection + loaded = pickle.loads(dumped) + + # check that introspecting the re-pickled schema gives the same result + assert introspection_from_schema(loaded) == expected_introspection + + finally: + sys.setrecursionlimit(limit) + + @mark.timeout(60 * timeout_factor) + def can_deep_copy_pickled_big_schema(big_schema_sdl): # noqa: F811 + # introspect the original big 
schema
+            big_schema = build_schema(big_schema_sdl)
+            expected_introspection = introspection_from_schema(big_schema)
+
+            limit = sys.getrecursionlimit()
+            sys.setrecursionlimit(max(limit, 4000))  # needed for pickle
+
+            try:
+                # pickle and unpickle the schema
+                loaded = pickle.loads(pickle.dumps(big_schema))
+                # create a deepcopy of the unpickled schema
+                copied = deepcopy(loaded)
+
+                # check that introspecting the copied schema gives the same result
+                assert introspection_from_schema(copied) == expected_introspection
+
+            finally:
+                sys.setrecursionlimit(limit)
diff --git a/tests/utilities/test_print_schema.py b/tests/utilities/test_print_schema.py
index 4bc5a266..1d60aa41 100644
--- a/tests/utilities/test_print_schema.py
+++ b/tests/utilities/test_print_schema.py
@@ -691,7 +691,7 @@ def prints_introspection_schema():
               mutationType: __Type
 
               """
-              If this server support subscription, the type that subscription operations will be rooted at.
+              If this server supports subscription, the type that subscription operations will be rooted at.
               """
               subscriptionType: __Type
diff --git a/tests/utilities/test_strip_ignored_characters_fuzz.py b/tests/utilities/test_strip_ignored_characters_fuzz.py
index 7f75b8eb..5b038ca2 100644
--- a/tests/utilities/test_strip_ignored_characters_fuzz.py
+++ b/tests/utilities/test_strip_ignored_characters_fuzz.py
@@ -7,7 +7,7 @@
 from graphql.language import Lexer, Source, TokenKind
 from graphql.utilities import strip_ignored_characters
 
-from ..utils import dedent, gen_fuzz_strings
+from ..utils import dedent, gen_fuzz_strings, timeout_factor
 
 
 ignored_tokens = [
@@ -228,7 +228,7 @@ def does_not_strip_random_ignored_tokens_embedded_in_the_block_string():
         ).to_stay_the_same()
 
     @mark.slow
-    @mark.timeout(20)
+    @mark.timeout(80 * timeout_factor)
     def strips_ignored_characters_inside_random_block_strings():
         # Testing with length >7 is taking exponentially more time. However it is
         # highly recommended to test with increased limit if you make any change.
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
index d6392286..7657950a 100644
--- a/tests/utils/__init__.py
+++ b/tests/utils/__init__.py
@@ -1,7 +1,13 @@
 """Test utilities"""
 
+from platform import python_implementation
+
 from .dedent import dedent
 from .gen_fuzz_strings import gen_fuzz_strings
 
-__all__ = ["dedent", "gen_fuzz_strings"]
+# some tests can take much longer on PyPy
+timeout_factor = 4 if python_implementation() == "PyPy" else 1
+
+
+__all__ = ["dedent", "gen_fuzz_strings", "timeout_factor"]
diff --git a/tox.ini b/tox.ini
index 17bb0ae0..29f7de91 100644
--- a/tox.ini
+++ b/tox.ini
@@ -59,4 +59,6 @@ deps =
     pytest-timeout>=2,<3
     py37: typing-extensions>=4.3,<5
 commands =
+    # to also run the time-consuming tests: tox -e py310 -- --run-slow
+    # to run the benchmarks: tox -e py310 -- -k benchmarks --benchmark-enable
     pytest tests {posargs: --cov-report=term-missing --cov=graphql --cov=tests --cov-fail-under=100}
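
For readers unfamiliar with the pickling patterns that the tests above exercise (a singleton that keeps its identity across a pickle round-trip, and reserved type names that unpickle to the shared instance), here is a minimal, self-contained sketch. It is not graphql-core code and all names in it (Sentinel, NamedThing, _restore_named_thing) are hypothetical; it only illustrates the general __reduce__-based mechanism under those assumptions.

import pickle
from typing import Dict, Optional


class Sentinel:
    """A singleton that survives pickling with its identity intact."""

    _instance: Optional["Sentinel"] = None

    def __new__(cls) -> "Sentinel":
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __reduce__(self) -> str:
        # Returning a string tells pickle to restore the module-level global
        # of that name, so unpickling yields the very same object.
        return "SENTINEL"


SENTINEL = Sentinel()


class NamedThing:
    """A type whose reserved (well-known) names unpickle to shared instances."""

    reserved_types: Dict[str, "NamedThing"] = {}

    def __init__(self, name: str) -> None:
        self.name = name

    def __reduce__(self):
        # Delegate reconstruction to a module-level function that consults
        # the registry of reserved instances first.
        return _restore_named_thing, (self.name,)


def _restore_named_thing(name: str) -> NamedThing:
    # Reserved names map back to the shared instance; others are rebuilt.
    return NamedThing.reserved_types.get(name) or NamedThing(name)


NamedThing.reserved_types["Int"] = NamedThing("Int")

# A pickle round-trip preserves identity for the singleton and for reserved names,
# while ordinary instances are simply reconstructed with the same data.
assert pickle.loads(pickle.dumps(SENTINEL)) is SENTINEL
int_thing = NamedThing.reserved_types["Int"]
assert pickle.loads(pickle.dumps(int_thing)) is int_thing
assert pickle.loads(pickle.dumps(NamedThing("Foo"))).name == "Foo"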