@@ -119,18 +119,18 @@ class GraphqlWsConsumer(ch_websocket.AsyncJsonWebsocketConsumer):
119119 # confirmation is enabled.
120120 subscription_confirmation_message : Dict [str , Any ] = {"data" : None , "errors" : None }
121121
122- # Issue a warning to the log when operation/resolver takes longer
123- # than specified number in seconds. None disables the warning.
122+ # Issue a warning to the log when operation takes longer than
123+ # specified number in seconds. None disables the warning.
124124 warn_operation_timeout : Optional [float ] = 1
125- warn_resolver_timeout : Optional [float ] = 1
126125
127126 # The size of the subscription notification queue. If there are more
128127 # notifications (for a single subscription) than the given number,
129128 # then an oldest notification is dropped and a warning is logged.
130129 subscription_notification_queue_limit : int = 1024
131130
132131 # GraphQL middleware.
133- # List of functions (callables) like the following:
132+ # Instance of `graphql.MiddlewareManager` or the list of functions
133+ # (callables) like the following:
134134 # ```python
135135 # async def my_middleware(next_middleware, root, info, *args, **kwds):
136136 # result = next_middleware(root, info, *args, **kwds)
@@ -145,7 +145,7 @@ class GraphqlWsConsumer(ch_websocket.AsyncJsonWebsocketConsumer):
145145 # - https://graphql-core-3.readthedocs.io/en/latest/diffs.html#custom-middleware
146146 # Docs about async middlewares are still missing - read the
147147 # GraphQL-core sources to know more.
148- middleware : Sequence = []
148+ middleware : Optional [ graphql . Middleware ] = None
149149
150150 # Subscription implementation shall return this to tell consumer
151151 # to suppress subscription notification.
@@ -278,6 +278,15 @@ def __init__(self, *args, **kwargs):
278278 weakref .WeakValueDictionary ()
279279 )
280280
281+ # MiddlewareManager maintains internal cache for resolvers
282+ # wrapped with middlewares. Using the same manager for all
283+ # operations improves performance.
284+ self ._middleware = None
285+ if self .middleware :
286+ self ._middleware = self .middleware
287+ if not isinstance (self ._middleware , graphql .MiddlewareManager ):
288+ self ._middleware = graphql .MiddlewareManager (* self ._middleware )
289+
281290 super ().__init__ (* args , ** kwargs )
282291
283292 # ---------------------------------------------------------- CONSUMER EVENT HANDLERS
@@ -595,27 +604,6 @@ async def _on_gql_start(self, op_id, payload):
595604 assert doc_ast is not None
596605 assert op_ast is not None
597606
598- async def unbound_root_middleware (* args , ** kwds ):
599- """Unbound function for root middleware.
600-
601- `graphql.MiddlewareManager` accepts only unbound
602- functions as middleware.
603- """
604- return await self ._on_gql_start__root_middleware (
605- op_id , op_name , * args , ** kwds
606- )
607-
608- # NOTE: Middlewares order is important, root middleware
609- # should always be the farest from the real resolver (last
610- # in the middleware list). Because we want to calculate
611- # resolver execution time with middlewares included.
612- middlewares = list (self .middleware )
613- if self .warn_resolver_timeout is not None :
614- middlewares .append (unbound_root_middleware )
615- middleware_manager : Optional [graphql .MiddlewareManager ] = None
616- if middlewares :
617- middleware_manager = graphql .MiddlewareManager (* middlewares )
618-
619607 # If the operation is subscription.
620608 if op_ast .operation == graphql .language .ast .OperationType .SUBSCRIPTION :
621609 LOG .debug (
@@ -637,7 +625,7 @@ async def unbound_root_middleware(*args, **kwds):
637625 op_id ,
638626 op_name ,
639627 ),
640- middleware = middleware_manager ,
628+ middleware = self . _middleware ,
641629 execution_context_class = self ._SubscriptionExecutionContext ,
642630 )
643631
@@ -715,6 +703,7 @@ async def consume_stream():
715703 # equals to `__schema`. This is a more robust way. But
716704 # it will eat up more CPU pre each query. For now lets
717705 # check only a query name.
706+ middleware_manager = self ._middleware
718707 if op_name == "IntrospectionQuery" :
719708 # No need to call middlewares for the
720709 # IntrospectionQuery. There no real resolvers. Only
@@ -901,82 +890,6 @@ async def map_source_to_response(payload: Any) -> graphql.ExecutionResult:
901890 # Map every source value to a ExecutionResult value.
902891 return graphql .MapAsyncIterator (result_or_stream , map_source_to_response )
903892
904- async def _on_gql_start__root_middleware (
905- self ,
906- operation_id : int ,
907- operation_name : str ,
908- next_middleware ,
909- root ,
910- info : graphql .GraphQLResolveInfo ,
911- * args ,
912- ** kwds ,
913- ):
914- """Root middleware injected right before resolver invocation.
915-
916- This middleware issues a warning if resolver execution time
917- exceeds a limit.
918-
919- Since this middleware always comes first in the list of
920- middlewares, it always receives resolver as the first
921- argument instead of another middleware.
922-
923- This is a part of START message processing routine so the name
924- prefixed with `_on_gql_start__` to make this explicit.
925-
926- Args:
927- resolver: Resolver to "wrap" into this middleware
928- root: Anything. Eventually passed to the resolver.
929- info: Passed to the resolver.
930-
931- Returns:
932- Any value: result returned by the resolver.
933- AsyncGenerator: when subscription starts.
934- """
935-
936- # Unwrap resolver from functools.partial or other wrappers.
937- real_resolver = self ._on_gql_start__unwrap (next_middleware )
938-
939- # Start measuring resolver execution time.
940- if self .warn_resolver_timeout is not None :
941- start_time = time .perf_counter ()
942-
943- # Execute resolver.
944- result = next_middleware (root , info , * args , ** kwds )
945- if inspect .isawaitable (result ):
946- result = await result
947-
948- # Warn about long resolver execution if the time limit exceeds.
949- if self .warn_resolver_timeout is not None :
950- duration = time .perf_counter () - start_time
951- if duration >= self .warn_resolver_timeout :
952- pretty_name = f"{ real_resolver .__qualname__ } "
953- if hasattr (real_resolver , "__self__" ):
954- pretty_name = f"{ real_resolver .__self__ .__qualname__ } .{ pretty_name } "
955- LOG .warning (
956- "Resolver %s took %.3f seconds (>%.3f)!"
957- " Operation %s(%s), path: %s." ,
958- pretty_name ,
959- duration ,
960- self .warn_resolver_timeout ,
961- operation_name ,
962- operation_id ,
963- info .path ,
964- )
965-
966- return result
967-
968- def _on_gql_start__unwrap (self , fn : Callable ) -> Callable :
969- """Auxiliary method which unwraps given function.
970-
971- This is a part of START message processing routine so the name
972- prefixed with `_on_gql_start__` to make this explicit.
973- """
974- if isinstance (fn , functools .partial ):
975- fn = self ._on_gql_start__unwrap (fn .func )
976- elif hasattr (fn , "__wrapped__" ):
977- fn = self ._on_gql_start__unwrap (fn .__wrapped__ )
978- return fn
979-
980893 async def _on_gql_start__initialize_subscription_stream (
981894 self ,
982895 operation_id : int ,
0 commit comments