diff --git a/derive/src/subscription.rs b/derive/src/subscription.rs index dbb9c2fa1..fa6be70c1 100644 --- a/derive/src/subscription.rs +++ b/derive/src/subscription.rs @@ -311,46 +311,50 @@ pub fn generate( let field = ::std::clone::Clone::clone(&field); let field_name = ::std::clone::Clone::clone(&field_name); async move { - let ctx_selection_set = query_env.create_context( - &schema_env, - ::std::option::Option::Some(#crate_name::QueryPathNode { - parent: ::std::option::Option::None, - segment: #crate_name::QueryPathSegment::Name(&field_name), - }), - &field.node.selection_set, - ); - - let execute_fut = async { - let parent_type = #gql_typename; - #[allow(bare_trait_objects)] - let ri = #crate_name::extensions::ResolveInfo { - path_node: ctx_selection_set.path_node.as_ref().unwrap(), - parent_type: &parent_type, - return_type: &<<#stream_ty as #crate_name::futures_util::stream::Stream>::Item as #crate_name::OutputType>::qualified_type_name(), - name: field.node.name.node.as_str(), - alias: field.node.alias.as_ref().map(|alias| alias.node.as_str()), - is_for_introspection: false, - field: &field.node, - }; - let resolve_fut = async { - #crate_name::OutputType::resolve(&msg, &ctx_selection_set, &*field) - .await - .map(::std::option::Option::Some) - }; - #crate_name::futures_util::pin_mut!(resolve_fut); - let mut resp = query_env.extensions.resolve(ri, &mut resolve_fut).await.map(|value| { - let mut map = #crate_name::indexmap::IndexMap::new(); - map.insert(::std::clone::Clone::clone(&field_name), value.unwrap_or_default()); - #crate_name::Response::new(#crate_name::Value::Object(map)) - }) - .unwrap_or_else(|err| #crate_name::Response::from_errors(::std::vec![err])); - - use ::std::iter::Extend; - resp.errors.extend(::std::mem::take(&mut *query_env.errors.lock().unwrap())); - resp + let f = |execute_data: ::std::option::Option<#crate_name::Data>| { + let schema_env = ::std::clone::Clone::clone(&schema_env); + let query_env = ::std::clone::Clone::clone(&query_env); + async move { + let ctx_selection_set = query_env.create_context( + &schema_env, + ::std::option::Option::Some(#crate_name::QueryPathNode { + parent: ::std::option::Option::None, + segment: #crate_name::QueryPathSegment::Name(&field_name), + }), + &field.node.selection_set, + execute_data.as_ref(), + ); + + let parent_type = #gql_typename; + #[allow(bare_trait_objects)] + let ri = #crate_name::extensions::ResolveInfo { + path_node: ctx_selection_set.path_node.as_ref().unwrap(), + parent_type: &parent_type, + return_type: &<<#stream_ty as #crate_name::futures_util::stream::Stream>::Item as #crate_name::OutputType>::qualified_type_name(), + name: field.node.name.node.as_str(), + alias: field.node.alias.as_ref().map(|alias| alias.node.as_str()), + is_for_introspection: false, + field: &field.node, + }; + let resolve_fut = async { + #crate_name::OutputType::resolve(&msg, &ctx_selection_set, &*field) + .await + .map(::std::option::Option::Some) + }; + #crate_name::futures_util::pin_mut!(resolve_fut); + let mut resp = query_env.extensions.resolve(ri, &mut resolve_fut).await.map(|value| { + let mut map = #crate_name::indexmap::IndexMap::new(); + map.insert(::std::clone::Clone::clone(&field_name), value.unwrap_or_default()); + #crate_name::Response::new(#crate_name::Value::Object(map)) + }) + .unwrap_or_else(|err| #crate_name::Response::from_errors(::std::vec![err])); + + use ::std::iter::Extend; + resp.errors.extend(::std::mem::take(&mut *query_env.errors.lock().unwrap())); + resp + } }; - #crate_name::futures_util::pin_mut!(execute_fut); - ::std::result::Result::Ok(query_env.extensions.execute(query_env.operation_name.as_deref(), &mut execute_fut).await) + ::std::result::Result::Ok(query_env.extensions.execute(query_env.operation_name.as_deref(), f).await) } } }); diff --git a/src/context.rs b/src/context.rs index 0bb253795..e84415ee3 100644 --- a/src/context.rs +++ b/src/context.rs @@ -240,6 +240,8 @@ pub struct ContextBase<'a, T> { pub schema_env: &'a SchemaEnv, #[doc(hidden)] pub query_env: &'a QueryEnv, + #[doc(hidden)] + pub execute_data: Option<&'a Data>, } #[doc(hidden)] @@ -251,8 +253,7 @@ pub struct QueryEnvInner { pub fragments: HashMap>, pub uploads: Vec, pub session_data: Arc, - pub ctx_data: Arc, - pub extension_data: Arc, + pub query_data: Arc, pub http_headers: Mutex, pub introspection_mode: IntrospectionMode, pub errors: Mutex>, @@ -282,6 +283,7 @@ impl QueryEnv { schema_env: &'a SchemaEnv, path_node: Option>, item: T, + execute_data: Option<&'a Data>, ) -> ContextBase<'a, T> { ContextBase { path_node, @@ -289,6 +291,7 @@ impl QueryEnv { item, schema_env, query_env: self, + execute_data, } } } @@ -322,6 +325,7 @@ impl<'a, T> ContextBase<'a, T> { item: field, schema_env: self.schema_env, query_env: self.query_env, + execute_data: self.execute_data.clone(), } } @@ -336,6 +340,7 @@ impl<'a, T> ContextBase<'a, T> { item: selection_set, schema_env: self.schema_env, query_env: self.query_env, + execute_data: self.execute_data.clone(), } } @@ -393,11 +398,10 @@ impl<'a, T> ContextBase<'a, T> { /// Gets the global data defined in the `Context` or `Schema` or `None` if /// the specified type data does not exist. pub fn data_opt(&self) -> Option<&'a D> { - self.query_env - .extension_data - .0 - .get(&TypeId::of::()) - .or_else(|| self.query_env.ctx_data.0.get(&TypeId::of::())) + self.execute_data + .as_ref() + .and_then(|execute_data| execute_data.get(&TypeId::of::())) + .or_else(|| self.query_env.query_data.0.get(&TypeId::of::())) .or_else(|| self.query_env.session_data.0.get(&TypeId::of::())) .or_else(|| self.schema_env.data.0.get(&TypeId::of::())) .and_then(|d| d.downcast_ref::()) @@ -611,6 +615,7 @@ impl<'a, T> ContextBase<'a, T> { item: self.item, schema_env: self.schema_env, query_env: self.query_env, + execute_data: self.execute_data.clone(), } } } diff --git a/src/dynamic/schema.rs b/src/dynamic/schema.rs index 99e28b281..f220a7aa1 100644 --- a/src/dynamic/schema.rs +++ b/src/dynamic/schema.rs @@ -306,9 +306,19 @@ impl Schema { self.0.env.registry.export_sdl(options) } - async fn execute_once(&self, env: QueryEnv, root_value: &FieldValue<'static>) -> Response { + async fn execute_once( + &self, + env: QueryEnv, + root_value: &FieldValue<'static>, + execute_data: Option, + ) -> Response { // execute - let ctx = env.create_context(&self.0.env, None, &env.operation.node.selection_set); + let ctx = env.create_context( + &self.0.env, + None, + &env.operation.node.selection_set, + execute_data.as_ref(), + ); let res = match &env.operation.node.ty { OperationType::Query => { async move { self.query_root() } @@ -361,14 +371,18 @@ impl Schema { .await { Ok((env, cache_control)) => { - let fut = async { - self.execute_once(env.clone(), &request.root_value) - .await - .cache_control(cache_control) + let f = { + |execute_data| { + let env = env.clone(); + async move { + self.execute_once(env, &request.root_value, execute_data) + .await + .cache_control(cache_control) + } + } }; - futures_util::pin_mut!(fut); env.extensions - .execute(env.operation_name.as_deref(), &mut fut) + .execute(env.operation_name.as_deref(), f) .await } Err(errors) => Response::from_errors(errors), @@ -420,7 +434,7 @@ impl Schema { }; if env.operation.node.ty != OperationType::Subscription { - yield schema.execute_once(env, &request.root_value).await; + yield schema.execute_once(env, &request.root_value, None).await; return; } @@ -428,6 +442,7 @@ impl Schema { &schema.0.env, None, &env.operation.node.selection_set, + None, ); let mut streams = Vec::new(); subscription.collect_streams(&schema, &ctx, &mut streams, &request.root_value); diff --git a/src/dynamic/subscription.rs b/src/dynamic/subscription.rs index 07c1c3514..14729a291 100644 --- a/src/dynamic/subscription.rs +++ b/src/dynamic/subscription.rs @@ -9,13 +9,7 @@ use crate::{ dynamic::{ resolve::resolve, FieldValue, InputValue, ObjectAccessor, ResolverContext, Schema, SchemaError, TypeRef, - }, - extensions::ResolveInfo, - parser::types::Selection, - registry::{Deprecation, MetaField, MetaType, Registry}, - subscription::BoxFieldStream, - ContextSelectionSet, Name, QueryPathNode, QueryPathSegment, Response, Result, ServerResult, - Value, + }, extensions::ResolveInfo, parser::types::Selection, registry::{Deprecation, MetaField, MetaType, Registry}, subscription::BoxFieldStream, ContextSelectionSet, Data, Name, QueryPathNode, QueryPathSegment, Response, Result, ServerResult, Value }; type BoxResolveFut<'a> = BoxFuture<'a, Result>>>>; @@ -233,34 +227,42 @@ impl Subscription { .map_err(|err| ctx_field.set_error_path(err.into_server_error(ctx_field.item.pos)))?; while let Some(value) = stream.next().await.transpose().map_err(|err| ctx_field.set_error_path(err.into_server_error(ctx_field.item.pos)))? { - let execute_fut = async { - let ri = ResolveInfo { - path_node: &QueryPathNode { - parent: None, - segment: QueryPathSegment::Name(&field_name), - }, - parent_type: schema.0.env.registry.subscription_type.as_ref().unwrap(), - return_type: &field_type.to_string(), - name: field.node.name.node.as_str(), - alias: field.node.alias.as_ref().map(|alias| alias.node.as_str()), - is_for_introspection: false, - field: &field.node, - }; - let resolve_fut = resolve(&schema, &ctx_field, &field_type, Some(&value)); - futures_util::pin_mut!(resolve_fut); - let value = ctx_field.query_env.extensions.resolve(ri, &mut resolve_fut).await; - - match value { - Ok(value) => { - let mut map = IndexMap::new(); - map.insert(field_name.clone(), value.unwrap_or_default()); - Response::new(Value::Object(map)) - }, - Err(err) => Response::from_errors(vec![err]), + let f = |execute_data: Option| { + let schema = schema.clone(); + let field_name = field_name.clone(); + let field_type = field_type.clone(); + let ctx_field = ctx_field.clone(); + + async move { + let mut ctx_field = ctx_field.clone(); + ctx_field.execute_data = execute_data.as_ref(); + let ri = ResolveInfo { + path_node: &QueryPathNode { + parent: None, + segment: QueryPathSegment::Name(&field_name), + }, + parent_type: schema.0.env.registry.subscription_type.as_ref().unwrap(), + return_type: &field_type.to_string(), + name: field.node.name.node.as_str(), + alias: field.node.alias.as_ref().map(|alias| alias.node.as_str()), + is_for_introspection: false, + field: &field.node, + }; + let resolve_fut = resolve(&schema, &ctx_field, &field_type, Some(&value)); + futures_util::pin_mut!(resolve_fut); + let value = ctx_field.query_env.extensions.resolve(ri, &mut resolve_fut).await; + + match value { + Ok(value) => { + let mut map = IndexMap::new(); + map.insert(field_name.clone(), value.unwrap_or_default()); + Response::new(Value::Object(map)) + }, + Err(err) => Response::from_errors(vec![err]), + } } }; - futures_util::pin_mut!(execute_fut); - let resp = ctx_field.query_env.extensions.execute(ctx_field.query_env.operation_name.as_deref(), &mut execute_fut).await; + let resp = ctx_field.query_env.extensions.execute(ctx_field.query_env.operation_name.as_deref(), f).await; let is_err = !resp.errors.is_empty(); yield resp; if is_err { diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 59d606fbf..6c47a23d8 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -18,7 +18,7 @@ use std::{ sync::Arc, }; -use futures_util::stream::BoxStream; +use futures_util::{future::BoxFuture, stream::BoxStream, FutureExt}; pub use self::analyzer::Analyzer; #[cfg(feature = "apollo_tracing")] @@ -141,7 +141,7 @@ type ParseFut<'a> = &'a mut (dyn Future = &'a mut (dyn Future>> + Send + Unpin); -type ExecuteFut<'a> = &'a mut (dyn Future + Send + Unpin); +type ExecuteFutFactory<'a> = Box) -> BoxFuture<'a, Response> + Send + 'a>; /// A future type used to resolve the field pub type ResolveFut<'a> = &'a mut (dyn Future>> + Send + Unpin); @@ -272,12 +272,27 @@ impl<'a> NextValidation<'a> { /// The remainder of a extension chain for execute. pub struct NextExecute<'a> { chain: &'a [Arc], - execute_fut: ExecuteFut<'a>, + execute_fut_factory: ExecuteFutFactory<'a>, + execute_data: Option, } impl<'a> NextExecute<'a> { - /// Call the [Extension::execute] function of next extension. - pub async fn run(self, ctx: &ExtensionContext<'_>, operation_name: Option<&str>) -> Response { + async fn internal_run( + self, + ctx: &ExtensionContext<'_>, + operation_name: Option<&str>, + data: Option, + ) -> Response { + let execute_data = match (self.execute_data, data) { + (Some(mut data1), Some(data2)) => { + data1.merge(data2); + Some(data1) + } + (Some(data), None) => Some(data), + (None, Some(data)) => Some(data), + (None, None) => None, + }; + if let Some((first, next)) = self.chain.split_first() { first .execute( @@ -285,14 +300,31 @@ impl<'a> NextExecute<'a> { operation_name, NextExecute { chain: next, - execute_fut: self.execute_fut, + execute_fut_factory: self.execute_fut_factory, + execute_data, }, ) .await } else { - self.execute_fut.await + (self.execute_fut_factory)(execute_data).await } } + + /// Call the [Extension::execute] function of next extension. + pub async fn run(self, ctx: &ExtensionContext<'_>, operation_name: Option<&str>) -> Response { + self.internal_run(ctx, operation_name, None).await + } + + /// Call the [Extension::execute] function of next extension with context + /// data. + pub async fn run_with_data( + self, + ctx: &ExtensionContext<'_>, + operation_name: Option<&str>, + data: Data, + ) -> Response { + self.internal_run(ctx, operation_name, Some(data)).await + } } /// The remainder of a extension chain for resolve. @@ -427,7 +459,7 @@ impl Extensions { } #[inline] - pub fn attach_query_data(&mut self, data: Arc) { + pub(crate) fn attach_query_data(&mut self, data: Arc) { self.query_data = Some(data); } @@ -491,14 +523,19 @@ impl Extensions { next.run(&self.create_context()).await } - pub async fn execute( - &self, + pub async fn execute<'a, 'b, F, T>( + &'a self, operation_name: Option<&str>, - execute_fut: ExecuteFut<'_>, - ) -> Response { + execute_fut_factory: F, + ) -> Response + where + F: FnOnce(Option) -> T + Send + 'a, + T: Future + Send + 'a, + { let next = NextExecute { chain: &self.extensions, - execute_fut, + execute_fut_factory: Box::new(|data| execute_fut_factory(data).boxed()), + execute_data: None, }; next.run(&self.create_context(), operation_name).await } diff --git a/src/resolver_utils/container.rs b/src/resolver_utils/container.rs index d00dfee64..10f381daf 100644 --- a/src/resolver_utils/container.rs +++ b/src/resolver_utils/container.rs @@ -275,6 +275,7 @@ impl<'a> Fields<'a> { item: directive, schema_env: ctx_field.schema_env, query_env: ctx_field.query_env, + execute_data: ctx_field.execute_data, }; let directive_instance = directive_factory .create(&ctx_directive, &directive.node)?; diff --git a/src/schema.rs b/src/schema.rs index fd3b3d859..f5905aa35 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -465,7 +465,7 @@ where ) } - async fn execute_once(&self, env: QueryEnv) -> Response { + async fn execute_once(&self, env: QueryEnv, execute_data: Option<&Data>) -> Response { // execute let ctx = ContextBase { path_node: None, @@ -473,6 +473,7 @@ where item: &env.operation.node.selection_set, schema_env: &self.0.env, query_env: &env, + execute_data, }; let res = match &env.operation.node.ty { @@ -523,14 +524,16 @@ where .await { Ok((env, cache_control)) => { - let fut = async { - self.execute_once(env.clone()) - .await - .cache_control(cache_control) + let f = |execute_data: Option| { + let env = env.clone(); + async move { + self.execute_once(env, execute_data.as_ref()) + .await + .cache_control(cache_control) + } }; - futures_util::pin_mut!(fut); env.extensions - .execute(env.operation_name.as_deref(), &mut fut) + .execute(env.operation_name.as_deref(), f) .await } Err(errors) => Response::from_errors(errors), @@ -581,7 +584,19 @@ where }; if env.operation.node.ty != OperationType::Subscription { - yield schema.execute_once(env).await.cache_control(cache_control); + let f = |execute_data: Option| { + let env = env.clone(); + let schema = schema.clone(); + async move { + schema.execute_once(env, execute_data.as_ref()) + .await + .cache_control(cache_control) + } + }; + yield env.extensions + .execute(env.operation_name.as_deref(), f) + .await + .cache_control(cache_control); return; } @@ -589,6 +604,7 @@ where &schema.0.env, None, &env.operation.node.selection_set, + None, ); let mut streams = Vec::new(); @@ -763,11 +779,10 @@ pub(crate) async fn prepare_request( complexity: Option, depth: Option, ) -> Result<(QueryEnv, CacheControl), Vec> { - let mut request = request; + let mut request = extensions.prepare_request(request).await?; let query_data = Arc::new(std::mem::take(&mut request.data)); extensions.attach_query_data(query_data.clone()); - let mut request = extensions.prepare_request(request).await?; let mut document = { let query = &request.query; let parsed_doc = request.parsed_query.take(); @@ -855,8 +870,7 @@ pub(crate) async fn prepare_request( fragments: document.fragments, uploads: request.uploads, session_data, - ctx_data: query_data, - extension_data: Arc::new(request.data), + query_data, http_headers: Default::default(), introspection_mode: request.introspection_mode, errors: Default::default(), diff --git a/tests/extension.rs b/tests/extension.rs index 0cc498614..22377e4b5 100644 --- a/tests/extension.rs +++ b/tests/extension.rs @@ -1,4 +1,7 @@ -use std::sync::Arc; +use std::sync::{ + atomic::{AtomicI32, Ordering}, + Arc, +}; use async_graphql::{ extensions::{ @@ -329,3 +332,204 @@ pub async fn test_extension_call_order() { ); } } + +#[tokio::test] +pub async fn query_execute_with_data() { + struct MyExtensionImpl(T); + + #[async_trait::async_trait] + impl Extension for MyExtensionImpl + where + T: Copy + Sync + Send + 'static, + { + async fn execute( + &self, + ctx: &ExtensionContext<'_>, + operation_name: Option<&str>, + next: NextExecute<'_>, + ) -> Response { + let mut data = Data::default(); + data.insert(self.0); + next.run_with_data(ctx, operation_name, data).await + } + } + + struct Query; + + #[Object] + impl Query { + async fn value(&self, ctx: &Context<'_>) -> Result { + Ok(*ctx.data::()? as i64 + ctx.data::()?) + } + } + + struct MyExtension(T); + + impl ExtensionFactory for MyExtension + where + T: Copy + Sync + Send + 'static, + { + fn create(&self) -> Arc { + Arc::new(MyExtensionImpl(self.0)) + } + } + + let schema = Schema::build(Query, EmptyMutation, EmptySubscription) + .extension(MyExtension(100i32)) + .extension(MyExtension(200i64)) + .finish(); + let query = "{ value }"; + assert_eq!( + schema.execute(query).await.data, + value!({ + "value": 300 + }) + ); +} + +#[tokio::test] +pub async fn subscription_execute_with_data() { + type Logs = Arc>>; + + struct MyExtensionImpl { + counter: Arc, + } + + impl MyExtensionImpl { + async fn append_log(&self, ctx: &ExtensionContext<'_>, element: LogElement) { + ctx.data::().unwrap().lock().await.push(element); + } + } + + #[async_trait::async_trait] + impl Extension for MyExtensionImpl { + async fn execute( + &self, + ctx: &ExtensionContext<'_>, + operation_name: Option<&str>, + next: NextExecute<'_>, + ) -> Response { + let mut data = Data::default(); + + let current_counter = self.counter.fetch_add(1, Ordering::SeqCst); + data.insert(current_counter); + self.append_log(ctx, LogElement::PreHook(current_counter)) + .await; + let resp = next.run_with_data(ctx, operation_name, data).await; + self.append_log(ctx, LogElement::PostHook(current_counter)) + .await; + resp + } + } + + struct MyExtension { + counter: Arc, + } + + impl ExtensionFactory for MyExtension { + fn create(&self) -> Arc { + Arc::new(MyExtensionImpl { + counter: self.counter.clone(), + }) + } + } + + #[derive(Debug, Eq, PartialEq)] + enum LogElement { + PreHook(i32), + OuterAccess(i32), + InnerAccess(i32), + PostHook(i32), + } + + let logs = Logs::default(); + let message_counter = Arc::new(AtomicI32::new(0)); + + #[derive(Clone, Copy)] + struct Inner(i32); + + #[Object] + impl Inner { + async fn value(&self, ctx: &Context<'_>) -> i32 { + if let Some(logs) = ctx.data_opt::() { + logs.lock().await.push(LogElement::InnerAccess(self.0)); + } + self.0 + } + } + + #[derive(Clone, Copy)] + struct Outer(Inner); + + #[Object] + impl Outer { + async fn inner(&self, ctx: &Context<'_>) -> Inner { + if let Some(logs) = ctx.data_opt::() { + logs.lock().await.push(LogElement::OuterAccess(self.0 .0)); + } + self.0 + } + } + + struct Query; + + #[Object] + impl Query { + async fn value(&self) -> i64 { + 0 + } + } + + struct Subscription; + + #[Subscription] + impl Subscription { + async fn outers(&self) -> impl Stream { + futures_util::stream::iter(10..13).map(Inner).map(Outer) + } + } + + let schema: Schema = + Schema::build(Query, EmptyMutation, Subscription) + .data(logs.clone()) + .extension(MyExtension { + counter: message_counter.clone(), + }) + .finish(); + let mut stream = schema.execute_stream("subscription { outers { inner { value } } }"); + + for i in 10i32..13 { + assert_eq!( + Response::new(value!({ + "outers": { + "inner": { + "value": i + } + } + })), + stream.next().await.unwrap() + ); + } + + { + let logs = logs.lock().await; + assert_eq!( + *logs, + vec![ + LogElement::PreHook(0), + LogElement::OuterAccess(10), + LogElement::InnerAccess(10), + LogElement::PostHook(0), + LogElement::PreHook(1), + LogElement::OuterAccess(11), + LogElement::InnerAccess(11), + LogElement::PostHook(1), + LogElement::PreHook(2), + LogElement::OuterAccess(12), + LogElement::InnerAccess(12), + LogElement::PostHook(2), + ], + "Log mismatch" + ); + } +}