diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ecc8f79..0f62af9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,9 +22,29 @@ jobs: strategy: fail-fast: false matrix: - rust: ["stable", "beta", "nightly"] + rust: ["stable"] backend: ["postgres", "mysql", "sqlite"] - os: [ubuntu-latest, macos-13, macos-15, windows-2019] + os: + [ubuntu-latest, macos-13, macos-15, windows-latest, ubuntu-22.04-arm] + include: + - rust: "beta" + backend: "postgres" + os: "ubuntu-latest" + - rust: "beta" + backend: "sqlite" + os: "ubuntu-latest" + - rust: "beta" + backend: "mysql" + os: "ubuntu-latest" + - rust: "nightly" + backend: "postgres" + os: "ubuntu-latest" + - rust: "nightly" + backend: "sqlite" + os: "ubuntu-latest" + - rust: "nightly" + backend: "mysql" + os: "ubuntu-latest" runs-on: ${{ matrix.os }} steps: - name: Checkout sources @@ -43,7 +63,7 @@ jobs: - name: Set environment variables shell: bash - if: matrix.backend == 'postgres' && matrix.os == 'windows-2019' + if: matrix.backend == 'postgres' && matrix.os == 'windows-latest' run: | echo "AWS_LC_SYS_NO_ASM=1" @@ -55,7 +75,7 @@ jobs: echo "RUSTDOCFLAGS=-D warnings" >> $GITHUB_ENV - uses: ilammy/setup-nasm@v1 - if: matrix.backend == 'postgres' && matrix.os == 'windows-2019' + if: matrix.backend == 'postgres' && matrix.os == 'windows-latest' - name: Install postgres (Linux) if: runner.os == 'Linux' && matrix.backend == 'postgres' @@ -78,37 +98,8 @@ jobs: - name: Install sqlite (Linux) if: runner.os == 'Linux' && matrix.backend == 'sqlite' run: | - curl -fsS --retry 3 -o sqlite-autoconf-3400100.tar.gz https://www.sqlite.org/2022/sqlite-autoconf-3400100.tar.gz - tar zxf sqlite-autoconf-3400100.tar.gz - cd sqlite-autoconf-3400100 - CFLAGS="$CFLAGS -O2 -fno-strict-aliasing \ - -DSQLITE_DEFAULT_FOREIGN_KEYS=1 \ - -DSQLITE_SECURE_DELETE \ - -DSQLITE_ENABLE_COLUMN_METADATA \ - -DSQLITE_ENABLE_FTS3_PARENTHESIS \ - -DSQLITE_ENABLE_RTREE=1 \ - -DSQLITE_SOUNDEX=1 \ - 
-DSQLITE_ENABLE_UNLOCK_NOTIFY \ - -DSQLITE_OMIT_LOOKASIDE=1 \ - -DSQLITE_ENABLE_DBSTAT_VTAB \ - -DSQLITE_ENABLE_UPDATE_DELETE_LIMIT=1 \ - -DSQLITE_ENABLE_LOAD_EXTENSION \ - -DSQLITE_ENABLE_JSON1 \ - -DSQLITE_LIKE_DOESNT_MATCH_BLOBS \ - -DSQLITE_THREADSAFE=1 \ - -DSQLITE_ENABLE_FTS3_TOKENIZER=1 \ - -DSQLITE_MAX_SCHEMA_RETRY=25 \ - -DSQLITE_ENABLE_PREUPDATE_HOOK \ - -DSQLITE_ENABLE_SESSION \ - -DSQLITE_ENABLE_STMTVTAB \ - -DSQLITE_MAX_VARIABLE_NUMBER=250000" \ - ./configure --prefix=/usr \ - --enable-threadsafe \ - --enable-dynamic-extensions \ - --libdir=/usr/lib/x86_64-linux-gnu \ - --libexecdir=/usr/lib/x86_64-linux-gnu/sqlite3 - sudo make - sudo make install + sudo apt-get update + sudo apt-get install libsqlite3-dev echo "DATABASE_URL=/tmp/test.db" >> $GITHUB_ENV - name: Install postgres (MacOS) @@ -184,8 +175,9 @@ jobs: run: | choco install sqlite cd /D C:\ProgramData\chocolatey\lib\SQLite\tools - call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat" + call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat" lib /machine:x64 /def:sqlite3.def /out:sqlite3.lib + - name: Set variables for sqlite (Windows) if: runner.os == 'Windows' && matrix.backend == 'sqlite' shell: bash @@ -240,11 +232,11 @@ jobs: - name: Check formating run: cargo +stable fmt --all -- --check minimal_rust_version: - name: Check Minimal supported rust version (1.78.0) + name: Check Minimal supported rust version (1.84.0) runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@1.78.0 + - uses: dtolnay/rust-toolchain@1.84.0 - uses: dtolnay/rust-toolchain@nightly - uses: taiki-e/install-action@cargo-hack - uses: taiki-e/install-action@cargo-minimal-versions @@ -253,4 +245,13 @@ jobs: # has broken min-version dependencies # cannot test sqlite yet as that crate # as broken min-version dependencies as well - run: cargo +1.78.0 minimal-versions check -p diesel-async --features 
"postgres bb8 deadpool mobc" + run: cargo +1.84.0 minimal-versions check -p diesel-async --features "postgres bb8 deadpool mobc" + all_features_build: + name: Check all feature combination build + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - uses: taiki-e/install-action@cargo-hack + - name: Check feature combinations + run: cargo hack check --feature-powerset --no-dev-deps --depth 2 diff --git a/CHANGELOG.md b/CHANGELOG.md index 1e87802..7b08a08 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,17 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/ ## [Unreleased] +## [0.6.1] - 2025-07-03 + +* Fix features for some dependencies + +## [0.6.0] - 2025-07-02 + +* Allow controlling the statement cache size +* Minimize dependencies features +* Bump minimal supported mysql_async version to 0.36.0 +* Fixing a bug in how we tracked open transactions that could lead to dangling transactions in specific cases + ## [0.5.2] - 2024-11-26 * Fixed an issue around transaction cancellation that could lead to connection pools containing connections with dangling transactions @@ -87,4 +98,7 @@ in the pool should be checked if they are still valid [0.4.1]: https://github.com/weiznich/diesel_async/compare/v0.4.0...v0.4.1 [0.5.0]: https://github.com/weiznich/diesel_async/compare/v0.4.0...v0.5.0 [0.5.1]: https://github.com/weiznich/diesel_async/compare/v0.5.0...v0.5.1 -[Unreleased]: https://github.com/weiznich/diesel_async/compare/v0.5.1...main +[0.5.2]: https://github.com/weiznich/diesel_async/compare/v0.5.1...v0.5.2 +[0.6.0]: https://github.com/weiznich/diesel_async/compare/v0.5.2...v0.6.0 +[0.6.1]: https://github.com/weiznich/diesel_async/compare/v0.6.0...v0.6.1 +[Unreleased]: https://github.com/weiznich/diesel_async/compare/v0.6.1...main diff --git a/Cargo.toml b/Cargo.toml index a0657e5..be4df15 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "diesel-async" 
-version = "0.5.2" +version = "0.6.1" authors = ["Georg Semmler "] edition = "2021" autotests = false @@ -10,60 +10,79 @@ repository = "https://github.com/weiznich/diesel_async" keywords = ["orm", "database", "sql", "async"] categories = ["database"] description = "An async extension for Diesel the safe, extensible ORM and Query Builder" -rust-version = "1.78.0" +rust-version = "1.84.0" [dependencies] -diesel = { version = "~2.2.0", default-features = false, features = [ - "i-implement-a-third-party-backend-and-opt-into-breaking-changes", -] } -async-trait = "0.1.66" +futures-core = "0.3.17" futures-channel = { version = "0.3.17", default-features = false, features = [ - "std", - "sink", + "std", + "sink", ], optional = true } futures-util = { version = "0.3.17", default-features = false, features = [ - "std", - "sink", + "alloc", + "sink", ] } tokio-postgres = { version = "0.7.10", optional = true } tokio = { version = "1.26", optional = true } -mysql_async = { version = "0.34", optional = true, default-features = false, features = [ - "minimal-rust", +mysql_async = { version = "0.36.0", optional = true, default-features = false, features = [ + "minimal-rust", ] } -mysql_common = { version = "0.32", optional = true, default-features = false } +mysql_common = { version = "0.35.3", optional = true, default-features = false } -bb8 = { version = "0.8", optional = true } +bb8 = { version = "0.9", optional = true } +async-trait = { version = "0.1.66", optional = true } deadpool = { version = "0.12", optional = true, default-features = false, features = [ - "managed", + "managed", ] } mobc = { version = ">=0.7,<0.10", optional = true } scoped-futures = { version = "0.1", features = ["std"] } +[dependencies.diesel] +version = "~2.2.0" +default-features = false +features = [ + "i-implement-a-third-party-backend-and-opt-into-breaking-changes", +] +git = "https://github.com/diesel-rs/diesel" +branch = "master" + [dev-dependencies] tokio = { version = "1.12.0", features = 
["rt", "macros", "rt-multi-thread"] } cfg-if = "1" chrono = "0.4" -diesel = { version = "2.2.0", default-features = false, features = ["chrono"] } -diesel_migrations = "2.2.0" assert_matches = "1.0.1" +[dev-dependencies.diesel] +version = "~2.2.0" +default-features = false +features = [ + "chrono" +] +git = "https://github.com/diesel-rs/diesel" +branch = "master" + +[dev-dependencies.diesel_migrations] +version = "2.2.0" +git = "https://github.com/diesel-rs/diesel" +branch = "master" + [features] default = [] mysql = [ - "diesel/mysql_backend", - "mysql_async", - "mysql_common", - "futures-channel", - "tokio", + "diesel/mysql_backend", + "mysql_async", + "mysql_common", + "futures-channel", + "tokio", ] postgres = ["diesel/postgres_backend", "tokio-postgres", "tokio", "tokio/rt"] sqlite = ["diesel/sqlite", "sync-connection-wrapper"] sync-connection-wrapper = ["tokio/rt"] -async-connection-wrapper = ["tokio/net"] +async-connection-wrapper = ["tokio/net", "tokio/rt"] pool = [] r2d2 = ["pool", "diesel/r2d2"] bb8 = ["pool", "dep:bb8"] -mobc = ["pool", "dep:mobc"] +mobc = ["pool", "dep:mobc", "dep:async-trait", "tokio/sync"] deadpool = ["pool", "dep:deadpool"] [[test]] @@ -73,15 +92,15 @@ harness = true [package.metadata.docs.rs] features = [ - "postgres", - "mysql", - "sqlite", - "deadpool", - "bb8", - "mobc", - "async-connection-wrapper", - "sync-connection-wrapper", - "r2d2", + "postgres", + "mysql", + "sqlite", + "deadpool", + "bb8", + "mobc", + "async-connection-wrapper", + "sync-connection-wrapper", + "r2d2", ] no-default-features = true rustc-args = ["--cfg", "docsrs"] @@ -89,8 +108,8 @@ rustdoc-args = ["--cfg", "docsrs"] [workspace] members = [ - ".", - "examples/postgres/pooled-with-rustls", - "examples/postgres/run-pending-migrations-with-rustls", - "examples/sync-wrapper", + ".", + "examples/postgres/pooled-with-rustls", + "examples/postgres/run-pending-migrations-with-rustls", + "examples/sync-wrapper", ] diff --git 
a/examples/postgres/pooled-with-rustls/Cargo.toml b/examples/postgres/pooled-with-rustls/Cargo.toml index 28c6093..a39754f 100644 --- a/examples/postgres/pooled-with-rustls/Cargo.toml +++ b/examples/postgres/pooled-with-rustls/Cargo.toml @@ -6,11 +6,17 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -diesel = { version = "2.2.0", default-features = false, features = ["postgres"] } -diesel-async = { version = "0.5.0", path = "../../../", features = ["bb8", "postgres"] } +diesel-async = { version = "0.6.0", path = "../../../", features = ["bb8", "postgres"] } futures-util = "0.3.21" rustls = "0.23.8" -rustls-native-certs = "0.7.1" +rustls-platform-verifier = "0.5.0" tokio = { version = "1.2.0", default-features = false, features = ["macros", "rt-multi-thread"] } tokio-postgres = "0.7.7" -tokio-postgres-rustls = "0.12.0" +tokio-postgres-rustls = "0.13.0" + + +[dependencies.diesel] +version = "2.2.0" +default-features = false +git = "https://github.com/diesel-rs/diesel" +branch = "master" diff --git a/examples/postgres/pooled-with-rustls/src/main.rs b/examples/postgres/pooled-with-rustls/src/main.rs index 87a8eb4..c3a0fc5 100644 --- a/examples/postgres/pooled-with-rustls/src/main.rs +++ b/examples/postgres/pooled-with-rustls/src/main.rs @@ -5,6 +5,8 @@ use diesel_async::pooled_connection::ManagerConfig; use diesel_async::AsyncPgConnection; use futures_util::future::BoxFuture; use futures_util::FutureExt; +use rustls::ClientConfig; +use rustls_platform_verifier::ConfigVerifierExt; use std::time::Duration; #[tokio::main] @@ -23,7 +25,7 @@ async fn main() -> Result<(), Box> { // means this will check whether the provided certificate is valid for the given database host. 
// // `libpq` does not perform these checks by default (https://www.postgresql.org/docs/current/libpq-connect.html) - // If you hit a TLS error while conneting to the database double check your certificates + // If you hit a TLS error while connecting to the database double check your certificates let pool = Pool::builder() .max_size(10) .min_idle(Some(5)) @@ -39,12 +41,10 @@ async fn main() -> Result<(), Box> { Ok(()) } -fn establish_connection(config: &str) -> BoxFuture> { +fn establish_connection(config: &str) -> BoxFuture<'_, ConnectionResult> { let fut = async { // We first set up the way we want rustls to work. - let rustls_config = rustls::ClientConfig::builder() - .with_root_certificates(root_certs()) - .with_no_client_auth(); + let rustls_config = ClientConfig::with_platform_verifier(); let tls = tokio_postgres_rustls::MakeRustlsConnect::new(rustls_config); let (client, conn) = tokio_postgres::connect(config, tls) .await @@ -54,10 +54,3 @@ fn establish_connection(config: &str) -> BoxFuture rustls::RootCertStore { - let mut roots = rustls::RootCertStore::empty(); - let certs = rustls_native_certs::load_native_certs().expect("Certs not loadable!"); - roots.add_parsable_certificates(certs); - roots -} diff --git a/examples/postgres/run-pending-migrations-with-rustls/Cargo.toml b/examples/postgres/run-pending-migrations-with-rustls/Cargo.toml index 2f54ab4..f9066f3 100644 --- a/examples/postgres/run-pending-migrations-with-rustls/Cargo.toml +++ b/examples/postgres/run-pending-migrations-with-rustls/Cargo.toml @@ -6,12 +6,21 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -diesel = { version = "2.2.0", default-features = false, features = ["postgres"] } -diesel-async = { version = "0.5.0", path = "../../../", features = ["bb8", "postgres", "async-connection-wrapper"] } -diesel_migrations = "2.2.0" +diesel-async = { version = "0.6.0", path = "../../../", features = ["bb8", 
"postgres", "async-connection-wrapper"] } futures-util = "0.3.21" -rustls = "0.23.10" -rustls-native-certs = "0.7.1" +rustls = "0.23.8" +rustls-platform-verifier = "0.5.0" tokio = { version = "1.2.0", default-features = false, features = ["macros", "rt-multi-thread"] } tokio-postgres = "0.7.7" -tokio-postgres-rustls = "0.12.0" +tokio-postgres-rustls = "0.13.0" + +[dependencies.diesel] +version = "2.2.0" +default-features = false +git = "https://github.com/diesel-rs/diesel" +branch = "master" + +[dependencies.diesel_migrations] +version = "2.2.0" +git = "https://github.com/diesel-rs/diesel" +branch = "master" diff --git a/examples/postgres/run-pending-migrations-with-rustls/src/main.rs b/examples/postgres/run-pending-migrations-with-rustls/src/main.rs index 1fb0c0f..6c0781c 100644 --- a/examples/postgres/run-pending-migrations-with-rustls/src/main.rs +++ b/examples/postgres/run-pending-migrations-with-rustls/src/main.rs @@ -4,6 +4,8 @@ use diesel_async::AsyncPgConnection; use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness}; use futures_util::future::BoxFuture; use futures_util::FutureExt; +use rustls::ClientConfig; +use rustls_platform_verifier::ConfigVerifierExt; pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!(); @@ -25,12 +27,10 @@ async fn main() -> Result<(), Box> { Ok(()) } -fn establish_connection(config: &str) -> BoxFuture> { +fn establish_connection(config: &str) -> BoxFuture<'_, ConnectionResult> { let fut = async { // We first set up the way we want rustls to work. 
- let rustls_config = rustls::ClientConfig::builder() - .with_root_certificates(root_certs()) - .with_no_client_auth(); + let rustls_config = ClientConfig::with_platform_verifier(); let tls = tokio_postgres_rustls::MakeRustlsConnect::new(rustls_config); let (client, conn) = tokio_postgres::connect(config, tls) .await @@ -39,10 +39,3 @@ fn establish_connection(config: &str) -> BoxFuture rustls::RootCertStore { - let mut roots = rustls::RootCertStore::empty(); - let certs = rustls_native_certs::load_native_certs().expect("Certs not loadable!"); - roots.add_parsable_certificates(certs); - roots -} diff --git a/examples/sync-wrapper/Cargo.toml b/examples/sync-wrapper/Cargo.toml index d578028..667da14 100644 --- a/examples/sync-wrapper/Cargo.toml +++ b/examples/sync-wrapper/Cargo.toml @@ -6,12 +6,22 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -diesel = { version = "2.2.0", default-features = false, features = ["returning_clauses_for_sqlite_3_35"] } -diesel-async = { version = "0.5.0", path = "../../", features = ["sync-connection-wrapper", "async-connection-wrapper"] } -diesel_migrations = "2.2.0" +diesel-async = { version = "0.6.0", path = "../../", features = ["sync-connection-wrapper", "async-connection-wrapper"] } futures-util = "0.3.21" tokio = { version = "1.2.0", default-features = false, features = ["macros", "rt-multi-thread"] } +[dependencies.diesel] +version = "2.2.0" +default-features = false +features = ["returning_clauses_for_sqlite_3_35"] +git = "https://github.com/diesel-rs/diesel" +branch = "master" + +[dependencies.diesel_migrations] +version = "2.2.0" +git = "https://github.com/diesel-rs/diesel" +branch = "master" + [features] default = ["sqlite"] sqlite = ["diesel-async/sqlite"] diff --git a/src/async_connection_wrapper.rs b/src/async_connection_wrapper.rs index 2bb0ae4..4e11078 100644 --- a/src/async_connection_wrapper.rs +++ b/src/async_connection_wrapper.rs @@ 
-9,9 +9,9 @@ //! as replacement for the existing connection //! implementations provided by diesel -use futures_util::Future; -use futures_util::Stream; +use futures_core::Stream; use futures_util::StreamExt; +use std::future::Future; use std::pin::Pin; /// This is a helper trait that allows to customize the @@ -100,7 +100,7 @@ pub type AsyncConnectionWrapper = pub use self::implementation::AsyncConnectionWrapper; mod implementation { - use diesel::connection::{Instrumentation, SimpleConnection}; + use diesel::connection::{CacheSize, Instrumentation, SimpleConnection}; use std::ops::{Deref, DerefMut}; use super::*; @@ -123,6 +123,17 @@ mod implementation { } } + impl AsyncConnectionWrapper + where + C: crate::AsyncConnection, + { + /// Consumes the [`AsyncConnectionWrapper`] returning the wrapped inner + /// [`AsyncConnection`]. + pub fn into_inner(self) -> C { + self.inner + } + } + impl Deref for AsyncConnectionWrapper { type Target = C; @@ -187,6 +198,10 @@ mod implementation { fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) { self.inner.set_instrumentation(instrumentation); } + + fn set_prepared_statement_cache_size(&mut self, size: CacheSize) { + self.inner.set_prepared_statement_cache_size(size) + } } impl diesel::connection::LoadConnection for AsyncConnectionWrapper @@ -194,13 +209,15 @@ mod implementation { C: crate::AsyncConnection, B: BlockOn + Send, { - type Cursor<'conn, 'query> = AsyncCursorWrapper<'conn, C::Stream<'conn, 'query>, B> - where - Self: 'conn; + type Cursor<'conn, 'query> + = AsyncCursorWrapper<'conn, C::Stream<'conn, 'query>, B> + where + Self: 'conn; - type Row<'conn, 'query> = C::Row<'conn, 'query> - where - Self: 'conn; + type Row<'conn, 'query> + = C::Row<'conn, 'query> + where + Self: 'conn; fn load<'conn, 'query, T>( &'conn mut self, @@ -228,7 +245,7 @@ mod implementation { runtime: &'a B, } - impl<'a, S, B> Iterator for AsyncCursorWrapper<'a, S, B> + impl Iterator for AsyncCursorWrapper<'_, S, B> where S: 
Stream, B: BlockOn, diff --git a/src/lib.rs b/src/lib.rs index 1a4b49c..8102312 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,7 +19,7 @@ //! //! * [`AsyncMysqlConnection`] (enabled by the `mysql` feature) //! * [`AsyncPgConnection`] (enabled by the `postgres` feature) -//! * [`SyncConnectionWrapper`] (enabled by the `sync-connection-wrapper`/`sqlite` feature) +//! * [`SyncConnectionWrapper`](sync_connection_wrapper::SyncConnectionWrapper) (enabled by the `sync-connection-wrapper`/`sqlite` feature) //! //! Ordinary usage of `diesel-async` assumes that you just replace the corresponding sync trait //! method calls and connections with their async counterparts. @@ -74,13 +74,15 @@ )] use diesel::backend::Backend; -use diesel::connection::Instrumentation; +use diesel::connection::{CacheSize, Instrumentation}; use diesel::query_builder::{AsQuery, QueryFragment, QueryId}; -use diesel::result::Error; use diesel::row::Row; use diesel::{ConnectionResult, QueryResult}; -use futures_util::{Future, Stream}; +use futures_core::future::BoxFuture; +use futures_core::Stream; +use futures_util::FutureExt; use std::fmt::Debug; +use std::future::Future; pub use scoped_futures; use scoped_futures::{ScopedBoxFuture, ScopedFutureExt}; @@ -115,22 +117,16 @@ pub use self::transaction_manager::{AnsiTransactionManager, TransactionManager}; /// Perform simple operations on a backend. /// /// You should likely use [`AsyncConnection`] instead. -#[async_trait::async_trait] pub trait SimpleAsyncConnection { /// Execute multiple SQL statements within the same string. /// /// This function is used to execute migrations, /// which may contain more than one SQL statement. - async fn batch_execute(&mut self, query: &str) -> QueryResult<()>; + fn batch_execute(&mut self, query: &str) -> impl Future> + Send; } -/// An async connection to a database -/// -/// This trait represents a n async database connection. 
It can be used to query the database through -/// the query dsl provided by diesel, custom extensions or raw sql queries. It essentially mirrors -/// the sync diesel [`Connection`](diesel::connection::Connection) implementation -#[async_trait::async_trait] -pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send { +/// Core trait for an async database connection +pub trait AsyncConnectionCore: SimpleAsyncConnection + Send { /// The future returned by `AsyncConnection::execute` type ExecuteFuture<'conn, 'query>: Future> + Send; /// The future returned by `AsyncConnection::load` @@ -143,6 +139,37 @@ pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send { /// The backend this type connects to type Backend: Backend; + #[doc(hidden)] + fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> + where + T: AsQuery + 'query, + T::Query: QueryFragment + QueryId + 'query; + + #[doc(hidden)] + fn execute_returning_count<'conn, 'query, T>( + &'conn mut self, + source: T, + ) -> Self::ExecuteFuture<'conn, 'query> + where + T: QueryFragment + QueryId + 'query; + + // These functions allow the associated types (`ExecuteFuture`, `LoadFuture`, etc.) to + // compile without a `where Self: '_` clause. This is needed because the bound causes + // lifetime issues when using `transaction()` with generic `AsyncConnection`s. + // + // See: https://github.com/rust-lang/rust/issues/87479 + #[doc(hidden)] + fn _silence_lint_on_execute_future(_: Self::ExecuteFuture<'_, '_>) {} + #[doc(hidden)] + fn _silence_lint_on_load_future(_: Self::LoadFuture<'_, '_>) {} +} + +/// An async connection to a database +/// +/// This trait represents an async database connection. It can be used to query the database through +/// the query dsl provided by diesel, custom extensions or raw sql queries. 
It essentially mirrors +/// the sync diesel [`Connection`](diesel::connection::Connection) implementation +pub trait AsyncConnection: AsyncConnectionCore + Sized { #[doc(hidden)] type TransactionManager: TransactionManager; @@ -151,7 +178,7 @@ pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send { /// The argument to this method and the method's behavior varies by backend. /// See the documentation for that backend's connection class /// for details about what it accepts and how it behaves. - async fn establish(database_url: &str) -> ConnectionResult; + fn establish(database_url: &str) -> impl Future> + Send; /// Executes the given function inside of a database transaction /// @@ -230,34 +257,44 @@ pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send { /// # Ok(()) /// # } /// ``` - async fn transaction<'a, R, E, F>(&mut self, callback: F) -> Result + fn transaction<'a, 'conn, R, E, F>( + &'conn mut self, + callback: F, + ) -> BoxFuture<'conn, Result> + // we cannot use `impl Trait` here due to bugs in rustc + // https://github.com/rust-lang/rust/issues/100013 + //impl Future> + Send + 'async_trait where F: for<'r> FnOnce(&'r mut Self) -> ScopedBoxFuture<'a, 'r, Result> + Send + 'a, E: From + Send + 'a, R: Send + 'a, + 'a: 'conn, { - Self::TransactionManager::transaction(self, callback).await + Self::TransactionManager::transaction(self, callback).boxed() } /// Creates a transaction that will never be committed. This is useful for /// tests. 
Panics if called while inside of a transaction or /// if called with a connection containing a broken transaction - async fn begin_test_transaction(&mut self) -> QueryResult<()> { + fn begin_test_transaction(&mut self) -> impl Future> + Send { use diesel::connection::TransactionManagerStatus; - match Self::TransactionManager::transaction_manager_status_mut(self) { - TransactionManagerStatus::Valid(valid_status) => { - assert_eq!(None, valid_status.transaction_depth()) - } - TransactionManagerStatus::InError => panic!("Transaction manager in error"), - }; - Self::TransactionManager::begin_transaction(self).await?; - // set the test transaction flag - // to prevent that this connection gets dropped in connection pools - // Tests commonly set the poolsize to 1 and use `begin_test_transaction` - // to prevent modifications to the schema - Self::TransactionManager::transaction_manager_status_mut(self).set_test_transaction_flag(); - Ok(()) + async { + match Self::TransactionManager::transaction_manager_status_mut(self) { + TransactionManagerStatus::Valid(valid_status) => { + assert_eq!(None, valid_status.transaction_depth()) + } + TransactionManagerStatus::InError => panic!("Transaction manager in error"), + }; + Self::TransactionManager::begin_transaction(self).await?; + // set the test transaction flag + // to prevent that this connection gets dropped in connection pools + // Tests commonly set the poolsize to 1 and use `begin_test_transaction` + // to prevent modifications to the schema + Self::TransactionManager::transaction_manager_status_mut(self) + .set_test_transaction_flag(); + Ok(()) + } } /// Executes the given function inside a transaction, but does not commit @@ -297,61 +334,46 @@ pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send { /// # Ok(()) /// # } /// ``` - async fn test_transaction<'a, R, E, F>(&'a mut self, f: F) -> R + fn test_transaction<'conn, 'a, R, E, F>( + &'conn mut self, + f: F, + ) -> impl Future + Send + 'conn where F: for<'r> 
FnOnce(&'r mut Self) -> ScopedBoxFuture<'a, 'r, Result> + Send + 'a, E: Debug + Send + 'a, R: Send + 'a, - Self: 'a, + 'a: 'conn, { use futures_util::TryFutureExt; - - let mut user_result = None; - let _ = self - .transaction::(|c| { - f(c).map_err(|_| Error::RollbackTransaction) - .and_then(|r| { - user_result = Some(r); - futures_util::future::ready(Err(Error::RollbackTransaction)) - }) - .scope_boxed() - }) - .await; - user_result.expect("Transaction did not succeed") + let (user_result_tx, user_result_rx) = std::sync::mpsc::channel(); + self.transaction::(move |conn| { + f(conn) + .map_err(|_| diesel::result::Error::RollbackTransaction) + .and_then(move |r| { + let _ = user_result_tx.send(r); + std::future::ready(Err(diesel::result::Error::RollbackTransaction)) + }) + .scope_boxed() + }) + .then(move |_r| { + let r = user_result_rx + .try_recv() + .expect("Transaction did not succeed"); + std::future::ready(r) + }) } - #[doc(hidden)] - fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> - where - T: AsQuery + 'query, - T::Query: QueryFragment + QueryId + 'query; - - #[doc(hidden)] - fn execute_returning_count<'conn, 'query, T>( - &'conn mut self, - source: T, - ) -> Self::ExecuteFuture<'conn, 'query> - where - T: QueryFragment + QueryId + 'query; - #[doc(hidden)] fn transaction_state( &mut self, ) -> &mut >::TransactionStateData; - // These functions allow the associated types (`ExecuteFuture`, `LoadFuture`, etc.) to - // compile without a `where Self: '_` clause. This is needed the because bound causes - // lifetime issues when using `transaction()` with generic `AsyncConnection`s. 
- // - // See: https://github.com/rust-lang/rust/issues/87479 - #[doc(hidden)] - fn _silence_lint_on_execute_future(_: Self::ExecuteFuture<'_, '_>) {} - #[doc(hidden)] - fn _silence_lint_on_load_future(_: Self::LoadFuture<'_, '_>) {} - #[doc(hidden)] fn instrumentation(&mut self) -> &mut dyn Instrumentation; /// Set a specific [`Instrumentation`] implementation for this connection fn set_instrumentation(&mut self, instrumentation: impl Instrumentation); + + /// Set the prepared statement cache size to [`CacheSize`] for this connection + fn set_prepared_statement_cache_size(&mut self, size: CacheSize); } diff --git a/src/mysql/mod.rs b/src/mysql/mod.rs index 9158f62..1d44650 100644 --- a/src/mysql/mod.rs +++ b/src/mysql/mod.rs @@ -1,19 +1,23 @@ -use crate::stmt_cache::{PrepareCallback, StmtCache}; -use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection}; -use diesel::connection::statement_cache::{MaybeCached, StatementCacheKey}; -use diesel::connection::Instrumentation; -use diesel::connection::InstrumentationEvent; +use crate::stmt_cache::{CallbackHelper, QueryFragmentHelper}; +use crate::{AnsiTransactionManager, AsyncConnection, AsyncConnectionCore, SimpleAsyncConnection}; +use diesel::connection::statement_cache::{ + MaybeCached, QueryFragmentForCachedStatement, StatementCache, +}; use diesel::connection::StrQueryHelper; +use diesel::connection::{CacheSize, Instrumentation}; +use diesel::connection::{DynInstrumentation, InstrumentationEvent}; use diesel::mysql::{Mysql, MysqlQueryBuilder, MysqlType}; use diesel::query_builder::QueryBuilder; use diesel::query_builder::{bind_collector::RawBytesBindCollector, QueryFragment, QueryId}; use diesel::result::{ConnectionError, ConnectionResult}; use diesel::QueryResult; -use futures_util::future::BoxFuture; -use futures_util::stream::{self, BoxStream}; -use futures_util::{Future, FutureExt, StreamExt, TryStreamExt}; +use futures_core::future::BoxFuture; +use futures_core::stream::BoxStream; +use 
futures_util::stream; +use futures_util::{FutureExt, StreamExt, TryStreamExt}; use mysql_async::prelude::Queryable; use mysql_async::{Opts, OptsBuilder, Statement}; +use std::future::Future; mod error_helper; mod row; @@ -27,12 +31,11 @@ use self::serialize::ToSqlHelper; /// `mysql://[user[:password]@]host/database_name` pub struct AsyncMysqlConnection { conn: mysql_async::Conn, - stmt_cache: StmtCache, + stmt_cache: StatementCache, transaction_manager: AnsiTransactionManager, - instrumentation: std::sync::Mutex>>, + instrumentation: DynInstrumentation, } -#[async_trait::async_trait] impl SimpleAsyncConnection for AsyncMysqlConnection { async fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> { self.instrumentation() @@ -61,31 +64,13 @@ const CONNECTION_SETUP_QUERIES: &[&str] = &[ "SET character_set_results = 'utf8mb4'", ]; -#[async_trait::async_trait] -impl AsyncConnection for AsyncMysqlConnection { +impl AsyncConnectionCore for AsyncMysqlConnection { type ExecuteFuture<'conn, 'query> = BoxFuture<'conn, QueryResult>; type LoadFuture<'conn, 'query> = BoxFuture<'conn, QueryResult>>; type Stream<'conn, 'query> = BoxStream<'conn, QueryResult>>; type Row<'conn, 'query> = MysqlRow; type Backend = Mysql; - type TransactionManager = AnsiTransactionManager; - - async fn establish(database_url: &str) -> diesel::ConnectionResult { - let mut instrumentation = diesel::connection::get_default_instrumentation(); - instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection( - database_url, - )); - let r = Self::establish_connection_inner(database_url).await; - instrumentation.on_connection_event(InstrumentationEvent::finish_establish_connection( - database_url, - r.as_ref().err(), - )); - let mut conn = r?; - conn.instrumentation = std::sync::Mutex::new(instrumentation); - Ok(conn) - } - fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> where T: diesel::query_builder::AsQuery, @@ -171,22 +156,40 @@ 
impl AsyncConnection for AsyncMysqlConnection { .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e))) }) } +} + +impl AsyncConnection for AsyncMysqlConnection { + type TransactionManager = AnsiTransactionManager; + + async fn establish(database_url: &str) -> diesel::ConnectionResult { + let mut instrumentation = DynInstrumentation::default_instrumentation(); + instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection( + database_url, + )); + let r = Self::establish_connection_inner(database_url).await; + instrumentation.on_connection_event(InstrumentationEvent::finish_establish_connection( + database_url, + r.as_ref().err(), + )); + let mut conn = r?; + conn.instrumentation = instrumentation; + Ok(conn) + } fn transaction_state(&mut self) -> &mut AnsiTransactionManager { &mut self.transaction_manager } fn instrumentation(&mut self) -> &mut dyn Instrumentation { - self.instrumentation - .get_mut() - .unwrap_or_else(|p| p.into_inner()) + &mut *self.instrumentation } fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) { - *self - .instrumentation - .get_mut() - .unwrap_or_else(|p| p.into_inner()) = Some(Box::new(instrumentation)); + self.instrumentation = instrumentation.into(); + } + + fn set_prepared_statement_cache_size(&mut self, size: CacheSize) { + self.stmt_cache.set_cache_size(size); } } @@ -207,17 +210,24 @@ fn update_transaction_manager_status( query_result } -#[async_trait::async_trait] -impl PrepareCallback for &'_ mut mysql_async::Conn { - async fn prepare( - self, - sql: &str, - _metadata: &[MysqlType], - _is_for_cache: diesel::connection::statement_cache::PrepareForCache, - ) -> QueryResult<(Statement, Self)> { - let s = self.prep(sql).await.map_err(ErrorHelper)?; - Ok((s, self)) - } +fn prepare_statement_helper<'a>( + conn: &'a mut mysql_async::Conn, + sql: &str, + _is_for_cache: diesel::connection::statement_cache::PrepareForCache, + _metadata: &[MysqlType], +) -> CallbackHelper> + Send> 
+{ + // ideally we wouldn't clone the SQL string here + // but as we usually cache statements anyway + // this is a fixed one time const + // + // The problem with not cloning it is that we then cannot express + // the right result lifetime anymore (at least not easily) + let sql = sql.to_owned(); + CallbackHelper(async move { + let s = conn.prep(sql).await.map_err(ErrorHelper)?; + Ok((s, conn)) + }) } impl AsyncMysqlConnection { @@ -229,11 +239,9 @@ impl AsyncMysqlConnection { use crate::run_query_dsl::RunQueryDsl; let mut conn = AsyncMysqlConnection { conn, - stmt_cache: StmtCache::new(), + stmt_cache: StatementCache::new(), transaction_manager: AnsiTransactionManager::default(), - instrumentation: std::sync::Mutex::new( - diesel::connection::get_default_instrumentation(), - ), + instrumentation: DynInstrumentation::default_instrumentation(), }; for stmt in CONNECTION_SETUP_QUERIES { @@ -286,36 +294,29 @@ impl AsyncMysqlConnection { } = bind_collector?; let is_safe_to_cache_prepared = is_safe_to_cache_prepared?; let sql = sql?; + let helper = QueryFragmentHelper { + sql, + safe_to_cache: is_safe_to_cache_prepared, + }; let inner = async { - let cache_key = if let Some(query_id) = query_id { - StatementCacheKey::Type(query_id) - } else { - StatementCacheKey::Sql { - sql: sql.clone(), - bind_types: metadata.clone(), - } - }; - let (stmt, conn) = stmt_cache - .cached_prepared_statement( - cache_key, - sql.clone(), - is_safe_to_cache_prepared, + .cached_statement_non_generic( + query_id, + &helper, + &Mysql, &metadata, conn, - instrumentation, + prepare_statement_helper, + &mut **instrumentation, ) .await?; callback(conn, stmt, ToSqlHelper { metadata, binds }).await }; let r = update_transaction_manager_status(inner.await, transaction_manager); - instrumentation - .get_mut() - .unwrap_or_else(|p| p.into_inner()) - .on_connection_event(InstrumentationEvent::finish_query( - &StrQueryHelper::new(&sql), - r.as_ref().err(), - )); + 
instrumentation.on_connection_event(InstrumentationEvent::finish_query( + &StrQueryHelper::new(&helper.sql), + r.as_ref().err(), + )); r } .boxed() @@ -370,9 +371,9 @@ impl AsyncMysqlConnection { Ok(AsyncMysqlConnection { conn, - stmt_cache: StmtCache::new(), + stmt_cache: StatementCache::new(), transaction_manager: AnsiTransactionManager::default(), - instrumentation: std::sync::Mutex::new(None), + instrumentation: DynInstrumentation::none(), }) } } @@ -427,3 +428,13 @@ mod tests { } } } + +impl QueryFragmentForCachedStatement for QueryFragmentHelper { + fn construct_sql(&self, _backend: &Mysql) -> QueryResult { + Ok(self.sql.clone()) + } + + fn is_safe_to_cache_prepared(&self, _backend: &Mysql) -> QueryResult { + Ok(self.safe_to_cache) + } +} diff --git a/src/mysql/row.rs b/src/mysql/row.rs index 5ed5cfc..d049c40 100644 --- a/src/mysql/row.rs +++ b/src/mysql/row.rs @@ -37,7 +37,11 @@ impl RowSealed for MysqlRow {} impl<'a> diesel::row::Row<'a, Mysql> for MysqlRow { type InnerPartialRow = Self; - type Field<'b> = MysqlField<'b> where Self: 'b, 'a: 'b; + type Field<'b> + = MysqlField<'b> + where + Self: 'b, + 'a: 'b; fn field_count(&self) -> usize { self.0.columns_ref().len() @@ -117,7 +121,7 @@ impl<'a> diesel::row::Row<'a, Mysql> for MysqlRow { Some(field) } - fn partial_row(&self, range: std::ops::Range) -> PartialRow { + fn partial_row(&self, range: std::ops::Range) -> PartialRow<'_, Self::InnerPartialRow> { PartialRow::new(self, range) } } @@ -128,7 +132,7 @@ pub struct MysqlField<'a> { name: Cow<'a, str>, } -impl<'a> diesel::row::Field<'a, Mysql> for MysqlField<'_> { +impl diesel::row::Field<'_, Mysql> for MysqlField<'_> { fn field_name(&self) -> Option<&str> { Some(&*self.name) } @@ -225,6 +229,7 @@ fn convert_type(column_type: ColumnType, column_flags: ColumnFlags) -> MysqlType | ColumnType::MYSQL_TYPE_UNKNOWN | ColumnType::MYSQL_TYPE_ENUM | ColumnType::MYSQL_TYPE_SET + | ColumnType::MYSQL_TYPE_VECTOR | ColumnType::MYSQL_TYPE_GEOMETRY => { 
unimplemented!("Hit an unsupported type: {:?}", column_type) } diff --git a/src/pg/mod.rs b/src/pg/mod.rs index 2ee7145..03e50ec 100644 --- a/src/pg/mod.rs +++ b/src/pg/mod.rs @@ -7,12 +7,14 @@ use self::error_helper::ErrorHelper; use self::row::PgRow; use self::serialize::ToSqlHelper; -use crate::stmt_cache::{PrepareCallback, StmtCache}; -use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection}; -use diesel::connection::statement_cache::{PrepareForCache, StatementCacheKey}; -use diesel::connection::Instrumentation; -use diesel::connection::InstrumentationEvent; +use crate::stmt_cache::{CallbackHelper, QueryFragmentHelper}; +use crate::{AnsiTransactionManager, AsyncConnection, AsyncConnectionCore, SimpleAsyncConnection}; +use diesel::connection::statement_cache::{ + PrepareForCache, QueryFragmentForCachedStatement, StatementCache, +}; use diesel::connection::StrQueryHelper; +use diesel::connection::{CacheSize, Instrumentation}; +use diesel::connection::{DynInstrumentation, InstrumentationEvent}; use diesel::pg::{ Pg, PgMetadataCache, PgMetadataCacheKey, PgMetadataLookup, PgQueryBuilder, PgTypeMetadata, }; @@ -20,16 +22,16 @@ use diesel::query_builder::bind_collector::RawBytesBindCollector; use diesel::query_builder::{AsQuery, QueryBuilder, QueryFragment, QueryId}; use diesel::result::{DatabaseErrorKind, Error}; use diesel::{ConnectionError, ConnectionResult, QueryResult}; -use futures_util::future::BoxFuture; +use futures_core::future::BoxFuture; +use futures_core::stream::BoxStream; use futures_util::future::Either; -use futures_util::stream::{BoxStream, TryStreamExt}; +use futures_util::stream::TryStreamExt; use futures_util::TryFutureExt; -use futures_util::{Future, FutureExt, StreamExt}; +use futures_util::{FutureExt, StreamExt}; use std::collections::{HashMap, HashSet}; +use std::future::Future; use std::sync::Arc; -use tokio::sync::broadcast; -use tokio::sync::oneshot; -use tokio::sync::Mutex; +use tokio::sync::{broadcast, mpsc, oneshot, 
Mutex}; use tokio_postgres::types::ToSql; use tokio_postgres::types::Type; use tokio_postgres::Statement; @@ -110,6 +112,48 @@ const FAKE_OID: u32 = 0; /// # } /// ``` /// +/// For more complex cases, an immutable reference to the connection needs to be used: +/// ```rust +/// # include!("../doctest_setup.rs"); +/// use diesel_async::RunQueryDsl; +/// +/// # +/// # #[tokio::main(flavor = "current_thread")] +/// # async fn main() { +/// # run_test().await.unwrap(); +/// # } +/// # +/// # async fn run_test() -> QueryResult<()> { +/// # use diesel::sql_types::{Text, Integer}; +/// # let conn = &mut establish_connection().await; +/// # +/// async fn fn12(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> { +/// let f1 = diesel::select(1_i32.into_sql::()).get_result::(&mut conn); +/// let f2 = diesel::select(2_i32.into_sql::()).get_result::(&mut conn); +/// +/// futures_util::try_join!(f1, f2) +/// } +/// +/// async fn fn34(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> { +/// let f3 = diesel::select(3_i32.into_sql::()).get_result::(&mut conn); +/// let f4 = diesel::select(4_i32.into_sql::()).get_result::(&mut conn); +/// +/// futures_util::try_join!(f3, f4) +/// } +/// +/// let f12 = fn12(&conn); +/// let f34 = fn34(&conn); +/// +/// let ((r1, r2), (r3, r4)) = futures_util::try_join!(f12, f34).unwrap(); +/// +/// assert_eq!(r1, 1); +/// assert_eq!(r2, 2); +/// assert_eq!(r3, 3); +/// assert_eq!(r4, 4); +/// # Ok(()) +/// # } +/// ``` +/// /// ## TLS /// /// Connections created by [`AsyncPgConnection::establish`] do not support TLS. 
@@ -122,17 +166,23 @@ const FAKE_OID: u32 = 0; /// [tokio_postgres_rustls]: https://docs.rs/tokio-postgres-rustls/0.12.0/tokio_postgres_rustls/ pub struct AsyncPgConnection { conn: Arc, - stmt_cache: Arc>>, + stmt_cache: Arc>>, transaction_state: Arc>, metadata_cache: Arc>, connection_future: Option>>, + notification_rx: Option>>, shutdown_channel: Option>, // a sync mutex is fine here as we only hold it for a really short time - instrumentation: Arc>>>, + instrumentation: Arc>, } -#[async_trait::async_trait] impl SimpleAsyncConnection for AsyncPgConnection { + async fn batch_execute(&mut self, query: &str) -> QueryResult<()> { + SimpleAsyncConnection::batch_execute(&mut &*self, query).await + } +} + +impl SimpleAsyncConnection for &AsyncPgConnection { async fn batch_execute(&mut self, query: &str) -> QueryResult<()> { self.record_instrumentation(InstrumentationEvent::start_query(&StrQueryHelper::new( query, @@ -143,7 +193,12 @@ impl SimpleAsyncConnection for AsyncPgConnection { .batch_execute(query) .map_err(ErrorHelper) .map_err(Into::into); + let r = drive_future(connection_future, batch_execute).await; + let r = { + let mut transaction_manager = self.transaction_state.lock().await; + update_transaction_manager_status(r, &mut transaction_manager) + }; self.record_instrumentation(InstrumentationEvent::finish_query( &StrQueryHelper::new(query), r.as_ref().err(), @@ -152,17 +207,73 @@ impl SimpleAsyncConnection for AsyncPgConnection { } } -#[async_trait::async_trait] -impl AsyncConnection for AsyncPgConnection { +impl AsyncConnectionCore for AsyncPgConnection { type LoadFuture<'conn, 'query> = BoxFuture<'query, QueryResult>>; type ExecuteFuture<'conn, 'query> = BoxFuture<'query, QueryResult>; type Stream<'conn, 'query> = BoxStream<'static, QueryResult>; type Row<'conn, 'query> = PgRow; type Backend = diesel::pg::Pg; + + fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> + where + T: AsQuery + 'query, + T::Query: QueryFragment + 
QueryId + 'query, + { + AsyncConnectionCore::load(&mut &*self, source) + } + + fn execute_returning_count<'conn, 'query, T>( + &'conn mut self, + source: T, + ) -> Self::ExecuteFuture<'conn, 'query> + where + T: QueryFragment + QueryId + 'query, + { + AsyncConnectionCore::execute_returning_count(&mut &*self, source) + } +} + +impl AsyncConnectionCore for &AsyncPgConnection { + type LoadFuture<'conn, 'query> = + ::LoadFuture<'conn, 'query>; + + type ExecuteFuture<'conn, 'query> = + ::ExecuteFuture<'conn, 'query>; + + type Stream<'conn, 'query> = ::Stream<'conn, 'query>; + + type Row<'conn, 'query> = ::Row<'conn, 'query>; + + type Backend = ::Backend; + + fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> + where + T: AsQuery + 'query, + T::Query: QueryFragment + QueryId + 'query, + { + let query = source.as_query(); + let load_future = self.with_prepared_statement(query, load_prepared); + + self.run_with_connection_future(load_future) + } + + fn execute_returning_count<'conn, 'query, T>( + &'conn mut self, + source: T, + ) -> Self::ExecuteFuture<'conn, 'query> + where + T: QueryFragment + QueryId + 'query, + { + let execute = self.with_prepared_statement(source, execute_prepared); + self.run_with_connection_future(execute) + } +} + +impl AsyncConnection for AsyncPgConnection { type TransactionManager = AnsiTransactionManager; async fn establish(database_url: &str) -> ConnectionResult { - let mut instrumentation = diesel::connection::get_default_instrumentation(); + let mut instrumentation = DynInstrumentation::default_instrumentation(); instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection( database_url, )); @@ -171,11 +282,12 @@ impl AsyncConnection for AsyncPgConnection { .await .map_err(ErrorHelper)?; - let (error_rx, shutdown_tx) = drive_connection(connection); + let (error_rx, notification_rx, shutdown_tx) = drive_connection(connection); let r = Self::setup( client, Some(error_rx), + 
Some(notification_rx), Some(shutdown_tx), Arc::clone(&instrumentation), ) @@ -191,28 +303,6 @@ impl AsyncConnection for AsyncPgConnection { r } - fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> - where - T: AsQuery + 'query, - T::Query: QueryFragment + QueryId + 'query, - { - let query = source.as_query(); - let load_future = self.with_prepared_statement(query, load_prepared); - - self.run_with_connection_future(load_future) - } - - fn execute_returning_count<'conn, 'query, T>( - &'conn mut self, - source: T, - ) -> Self::ExecuteFuture<'conn, 'query> - where - T: QueryFragment + QueryId + 'query, - { - let execute = self.with_prepared_statement(source, execute_prepared); - self.run_with_connection_future(execute) - } - fn transaction_state(&mut self) -> &mut AnsiTransactionManager { // there should be no other pending future when this is called // that means there is only one instance of this arc and @@ -229,14 +319,25 @@ impl AsyncConnection for AsyncPgConnection { // that means there is only one instance of this arc and // we can simply access the inner data if let Some(instrumentation) = Arc::get_mut(&mut self.instrumentation) { - instrumentation.get_mut().unwrap_or_else(|p| p.into_inner()) + &mut **(instrumentation.get_mut().unwrap_or_else(|p| p.into_inner())) } else { panic!("Cannot access shared instrumentation") } } fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) { - self.instrumentation = Arc::new(std::sync::Mutex::new(Some(Box::new(instrumentation)))); + self.instrumentation = Arc::new(std::sync::Mutex::new(instrumentation.into())); + } + + fn set_prepared_statement_cache_size(&mut self, size: CacheSize) { + // there should be no other pending future when this is called + // that means there is only one instance of this arc and + // we can simply access the inner data + if let Some(cache) = Arc::get_mut(&mut self.stmt_cache) { + cache.get_mut().set_cache_size(size) + } else { + 
panic!("Cannot access shared statement cache") + } } } @@ -286,32 +387,42 @@ fn update_transaction_manager_status( if let Err(diesel::result::Error::DatabaseError(DatabaseErrorKind::SerializationFailure, _)) = query_result { - transaction_manager - .status - .set_requires_rollback_maybe_up_to_top_level(true) + if !transaction_manager.is_commit { + transaction_manager + .status + .set_requires_rollback_maybe_up_to_top_level(true); + } } query_result } -#[async_trait::async_trait] -impl PrepareCallback for Arc { - async fn prepare( - self, - sql: &str, - metadata: &[PgTypeMetadata], - _is_for_cache: PrepareForCache, - ) -> QueryResult<(Statement, Self)> { - let bind_types = metadata - .iter() - .map(type_from_oid) - .collect::>>()?; - - let stmt = self - .prepare_typed(sql, &bind_types) +fn prepare_statement_helper( + conn: Arc, + sql: &str, + _is_for_cache: PrepareForCache, + metadata: &[PgTypeMetadata], +) -> CallbackHelper< + impl Future)>> + Send, +> { + let bind_types = metadata + .iter() + .map(type_from_oid) + .collect::>>(); + // ideally we wouldn't clone the SQL string here + // but as we usually cache statements anyway + // this is a fixed one time const + // + // The problem with not cloning it is that we then cannot express + // the right result lifetime anymore (at least not easily) + let sql = sql.to_string(); + CallbackHelper(async move { + let bind_types = bind_types?; + let stmt = conn + .prepare_typed(&sql, &bind_types) .await .map_err(ErrorHelper); - Ok((stmt?, self)) - } + Ok((stmt?, conn)) + }) } fn type_from_oid(t: &PgTypeMetadata) -> QueryResult { @@ -358,7 +469,7 @@ impl AsyncPgConnection { /// .await /// # } /// ``` - pub fn build_transaction(&mut self) -> TransactionBuilder { + pub fn build_transaction(&mut self) -> TransactionBuilder<'_, Self> { TransactionBuilder::new(self) } @@ -368,8 +479,9 @@ impl AsyncPgConnection { conn, None, None, + None, Arc::new(std::sync::Mutex::new( - 
DynInstrumentation::default_instrumentation(), )), ) .await @@ -384,14 +496,15 @@ impl AsyncPgConnection { where S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static, { - let (error_rx, shutdown_tx) = drive_connection(conn); + let (error_rx, notification_rx, shutdown_tx) = drive_connection(conn); Self::setup( client, Some(error_rx), + Some(notification_rx), Some(shutdown_tx), Arc::new(std::sync::Mutex::new( - diesel::connection::get_default_instrumentation(), + DynInstrumentation::default_instrumentation(), )), ) .await @@ -400,15 +513,17 @@ impl AsyncPgConnection { async fn setup( conn: tokio_postgres::Client, connection_future: Option>>, + notification_rx: Option>>, shutdown_channel: Option>, - instrumentation: Arc>>>, + instrumentation: Arc>, ) -> ConnectionResult { let mut conn = Self { conn: Arc::new(conn), - stmt_cache: Arc::new(Mutex::new(StmtCache::new())), + stmt_cache: Arc::new(Mutex::new(StatementCache::new())), transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())), metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())), connection_future, + notification_rx, shutdown_channel, instrumentation, }; @@ -426,10 +541,11 @@ impl AsyncPgConnection { async fn set_config_options(&mut self) -> QueryResult<()> { use crate::run_query_dsl::RunQueryDsl; - futures_util::try_join!( + futures_util::future::try_join( diesel::sql_query("SET TIME ZONE 'UTC'").execute(self), diesel::sql_query("SET CLIENT_ENCODING TO 'UTF8'").execute(self), - )?; + ) + .await?; Ok(()) } @@ -442,7 +558,7 @@ impl AsyncPgConnection { } fn with_prepared_statement<'a, T, F, R>( - &mut self, + &self, query: T, callback: fn(Arc, Statement, Vec) -> F, ) -> BoxFuture<'a, QueryResult> @@ -477,7 +593,7 @@ impl AsyncPgConnection { } fn with_prepared_statement_after_sql_built<'a, F, R>( - &mut self, + &self, callback: fn(Arc, Statement, Vec) -> F, is_safe_to_cache_prepared: QueryResult, query_id: Option, @@ -559,23 +675,27 @@ impl AsyncPgConnection { })?; } } - let key = 
match query_id { - Some(id) => StatementCacheKey::Type(id), - None => StatementCacheKey::Sql { - sql: sql.clone(), - bind_types: bind_collector.metadata.clone(), - }, - }; let stmt = { let mut stmt_cache = stmt_cache.lock().await; + let helper = QueryFragmentHelper { + sql: sql.clone(), + safe_to_cache: is_safe_to_cache_prepared, + }; + let instrumentation = Arc::clone(&instrumentation); stmt_cache - .cached_prepared_statement( - key, - sql.clone(), - is_safe_to_cache_prepared, + .cached_statement_non_generic( + query_id, + &helper, + &Pg, &bind_collector.metadata, raw_connection.clone(), - &instrumentation + prepare_statement_helper, + &mut move |event: InstrumentationEvent<'_>| { + // we wrap this lock into another callback to prevent locking + // the instrumentation longer than necessary + instrumentation.lock().unwrap_or_else(|e| e.into_inner()) + .on_connection_event(event); + }, ) .await? .0 @@ -612,6 +732,58 @@ impl AsyncPgConnection { .unwrap_or_else(|p| p.into_inner()) .on_connection_event(event); } + + /// See Postgres documentation for SQL commands [NOTIFY][] and [LISTEN][] + /// + /// The returned stream yields all notifications received by the connection, not only notifications received + /// after calling the function. The returned stream will never close, so no notifications will just result + /// in a pending state. + /// + /// If there's no connection available to poll, the stream will yield no notifications and be pending forever. + /// This can happen if you created the [`AsyncPgConnection`] by the [`try_from`] constructor. 
+ /// + /// [NOTIFY]: https://www.postgresql.org/docs/current/sql-notify.html + /// [LISTEN]: https://www.postgresql.org/docs/current/sql-listen.html + /// [`AsyncPgConnection`]: crate::pg::AsyncPgConnection + /// [`try_from`]: crate::pg::AsyncPgConnection::try_from + /// + /// ```rust + /// # include!("../doctest_setup.rs"); + /// # use scoped_futures::ScopedFutureExt; + /// # + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// # run_test().await.unwrap(); + /// # } + /// # + /// # async fn run_test() -> QueryResult<()> { + /// # use diesel_async::RunQueryDsl; + /// # use futures_util::StreamExt; + /// # let conn = &mut connection_no_transaction().await; + /// // register the notifications channel we want to receive notifications for + /// diesel::sql_query("LISTEN example_channel").execute(conn).await?; + /// // send some notification (usually done from a different connection/thread/application) + /// diesel::sql_query("NOTIFY example_channel, 'additional data'").execute(conn).await?; + /// + /// let mut notifications = std::pin::pin!(conn.notifications_stream()); + /// let mut notification = notifications.next().await.unwrap().unwrap(); + /// + /// assert_eq!(notification.channel, "example_channel"); + /// assert_eq!(notification.payload, "additional data"); + /// println!("Notification received from process with id {}", notification.process_id); + /// # Ok(()) + /// # } + /// ``` + pub fn notifications_stream( + &mut self, + ) -> impl futures_core::Stream> + '_ { + match &mut self.notification_rx { + None => Either::Left(futures_util::stream::pending()), + Some(rx) => Either::Right(futures_util::stream::unfold(rx, |rx| async { + rx.recv().await.map(move |item| (item, rx)) + })), + } + } } struct BindData { @@ -857,27 +1029,44 @@ async fn drive_future( } fn drive_connection( - conn: tokio_postgres::Connection, + mut conn: tokio_postgres::Connection, ) -> ( broadcast::Receiver>, + mpsc::UnboundedReceiver>, oneshot::Sender<()>, ) 
where S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static, { let (error_tx, error_rx) = tokio::sync::broadcast::channel(1); - let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); + let (notification_tx, notification_rx) = tokio::sync::mpsc::unbounded_channel(); + let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel(); + let mut conn = futures_util::stream::poll_fn(move |cx| conn.poll_message(cx)); tokio::spawn(async move { - match futures_util::future::select(shutdown_rx, conn).await { - Either::Left(_) | Either::Right((Ok(_), _)) => {} - Either::Right((Err(e), _)) => { - let _ = error_tx.send(Arc::new(e)); + loop { + match futures_util::future::select(&mut shutdown_rx, conn.next()).await { + Either::Left(_) | Either::Right((None, _)) => break, + Either::Right((Some(Ok(tokio_postgres::AsyncMessage::Notification(notif))), _)) => { + let _: Result<_, _> = notification_tx.send(Ok(diesel::pg::PgNotification { + process_id: notif.process_id(), + channel: notif.channel().to_owned(), + payload: notif.payload().to_owned(), + })); + } + Either::Right((Some(Ok(_)), _)) => {} + Either::Right((Some(Err(e)), _)) => { + let e = Arc::new(e); + let _: Result<_, _> = error_tx.send(e.clone()); + let _: Result<_, _> = + notification_tx.send(Err(error_helper::from_tokio_postgres_error(e))); + break; + } } } }); - (error_rx, shutdown_tx) + (error_rx, notification_rx, shutdown_tx) } #[cfg(any( @@ -894,17 +1083,31 @@ impl crate::pooled_connection::PoolableConnection for AsyncPgConnection { } } +impl QueryFragmentForCachedStatement for QueryFragmentHelper { + fn construct_sql(&self, _backend: &Pg) -> QueryResult { + Ok(self.sql.clone()) + } + + fn is_safe_to_cache_prepared(&self, _backend: &Pg) -> QueryResult { + Ok(self.safe_to_cache) + } +} + #[cfg(test)] mod tests { use super::*; use crate::run_query_dsl::RunQueryDsl; use diesel::sql_types::Integer; use diesel::IntoSql; + use futures_util::future::try_join; + use futures_util::try_join; + use 
scoped_futures::ScopedFutureExt; #[tokio::test] async fn pipelining() { let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set in order to run tests"); + let mut conn = crate::AsyncPgConnection::establish(&database_url) .await .unwrap(); @@ -915,9 +1118,100 @@ mod tests { let f1 = q1.get_result::(&mut conn); let f2 = q2.get_result::(&mut conn); - let (r1, r2) = futures_util::try_join!(f1, f2).unwrap(); + let (r1, r2) = try_join!(f1, f2).unwrap(); assert_eq!(r1, 1); assert_eq!(r2, 2); } + + #[tokio::test] + async fn pipelining_with_composed_futures() { + let database_url = + std::env::var("DATABASE_URL").expect("DATABASE_URL must be set in order to run tests"); + + let conn = crate::AsyncPgConnection::establish(&database_url) + .await + .unwrap(); + + async fn fn12(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> { + let f1 = diesel::select(1_i32.into_sql::()).get_result::(&mut conn); + let f2 = diesel::select(2_i32.into_sql::()).get_result::(&mut conn); + + try_join!(f1, f2) + } + + async fn fn34(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> { + let f3 = diesel::select(3_i32.into_sql::()).get_result::(&mut conn); + let f4 = diesel::select(4_i32.into_sql::()).get_result::(&mut conn); + + try_join!(f3, f4) + } + + let f12 = fn12(&conn); + let f34 = fn34(&conn); + + let ((r1, r2), (r3, r4)) = try_join!(f12, f34).unwrap(); + + assert_eq!(r1, 1); + assert_eq!(r2, 2); + assert_eq!(r3, 3); + assert_eq!(r4, 4); + } + + #[tokio::test] + async fn pipelining_with_composed_futures_and_transaction() { + let database_url = + std::env::var("DATABASE_URL").expect("DATABASE_URL must be set in order to run tests"); + + let mut conn = crate::AsyncPgConnection::establish(&database_url) + .await + .unwrap(); + + fn erase<'a, T: Future + Send + 'a>(t: T) -> impl Future + Send + 'a { + t + } + + async fn fn12(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> { + let f1 = diesel::select(1_i32.into_sql::()).get_result::(&mut conn); + let 
f2 = diesel::select(2_i32.into_sql::()).get_result::(&mut conn); + + erase(try_join(f1, f2)).await + } + + async fn fn34(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> { + let f3 = diesel::select(3_i32.into_sql::()).get_result::(&mut conn); + let f4 = diesel::select(4_i32.into_sql::()).get_result::(&mut conn); + + try_join(f3, f4).boxed().await + } + + async fn fn56(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> { + let f5 = diesel::select(5_i32.into_sql::()).get_result::(&mut conn); + let f6 = diesel::select(6_i32.into_sql::()).get_result::(&mut conn); + + try_join!(f5.boxed(), f6.boxed()) + } + + conn.transaction(|conn| { + async move { + let f12 = fn12(conn); + let f34 = fn34(conn); + let f56 = fn56(conn); + + let ((r1, r2), (r3, r4), (r5, r6)) = try_join!(f12, f34, f56).unwrap(); + + assert_eq!(r1, 1); + assert_eq!(r2, 2); + assert_eq!(r3, 3); + assert_eq!(r4, 4); + assert_eq!(r5, 5); + assert_eq!(r6, 6); + + QueryResult::<_>::Ok(()) + } + .scope_boxed() + }) + .await + .unwrap(); + } } diff --git a/src/pg/row.rs b/src/pg/row.rs index b1dafdb..c0c0be7 100644 --- a/src/pg/row.rs +++ b/src/pg/row.rs @@ -16,7 +16,11 @@ impl RowSealed for PgRow {} impl<'a> diesel::row::Row<'a, diesel::pg::Pg> for PgRow { type InnerPartialRow = Self; - type Field<'b> = PgField<'b> where Self: 'b, 'a: 'b; + type Field<'b> + = PgField<'b> + where + Self: 'b, + 'a: 'b; fn field_count(&self) -> usize { self.row.len() @@ -37,7 +41,7 @@ impl<'a> diesel::row::Row<'a, diesel::pg::Pg> for PgRow { fn partial_row( &self, range: std::ops::Range, - ) -> diesel::row::PartialRow { + ) -> diesel::row::PartialRow<'_, Self::InnerPartialRow> { PartialRow::new(self, range) } } diff --git a/src/pg/transaction_builder.rs b/src/pg/transaction_builder.rs index 1096433..095e7bd 100644 --- a/src/pg/transaction_builder.rs +++ b/src/pg/transaction_builder.rs @@ -310,7 +310,7 @@ where } } -impl<'a, C> QueryFragment for TransactionBuilder<'a, C> { +impl QueryFragment for 
TransactionBuilder<'_, C> { fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> { out.push_sql("BEGIN TRANSACTION"); if let Some(ref isolation_level) = self.isolation_level { diff --git a/src/pooled_connection/bb8.rs b/src/pooled_connection/bb8.rs index 8f5eba3..c920dc7 100644 --- a/src/pooled_connection/bb8.rs +++ b/src/pooled_connection/bb8.rs @@ -50,7 +50,6 @@ //! # Ok(()) //! # } //! ``` - use super::{AsyncDieselConnectionManager, PoolError, PoolableConnection}; use bb8::ManageConnection; use diesel::query_builder::QueryFragment; @@ -65,7 +64,6 @@ pub type PooledConnection<'a, C> = bb8::PooledConnection<'a, AsyncDieselConnecti /// Type alias for using [`bb8::RunError`] with [`diesel-async`] pub type RunError = bb8::RunError; -#[async_trait::async_trait] impl ManageConnection for AsyncDieselConnectionManager where C: PoolableConnection + 'static, diff --git a/src/pooled_connection/mod.rs b/src/pooled_connection/mod.rs index 21471b1..cbe9f60 100644 --- a/src/pooled_connection/mod.rs +++ b/src/pooled_connection/mod.rs @@ -5,14 +5,16 @@ //! * [deadpool](self::deadpool) //! * [bb8](self::bb8) //! 
* [mobc](self::mobc) -use crate::{AsyncConnection, SimpleAsyncConnection}; +use crate::{AsyncConnection, AsyncConnectionCore, SimpleAsyncConnection}; use crate::{TransactionManager, UpdateAndFetchResults}; use diesel::associations::HasTable; -use diesel::connection::Instrumentation; +use diesel::connection::{CacheSize, Instrumentation}; use diesel::QueryResult; -use futures_util::{future, FutureExt}; +use futures_core::future::BoxFuture; +use futures_util::FutureExt; use std::borrow::Cow; use std::fmt; +use std::future::Future; use std::ops::DerefMut; #[cfg(feature = "bb8")] @@ -45,11 +47,10 @@ impl std::error::Error for PoolError {} /// Type of the custom setup closure passed to [`ManagerConfig::custom_setup`] pub type SetupCallback = - Box future::BoxFuture> + Send + Sync>; + Box BoxFuture> + Send + Sync>; /// Type of the recycle check callback for the [`RecyclingMethod::CustomFunction`] variant -pub type RecycleCheckCallback = - dyn Fn(&mut C) -> future::BoxFuture> + Send + Sync; +pub type RecycleCheckCallback = dyn Fn(&mut C) -> BoxFuture> + Send + Sync; /// Possible methods of how a connection is recycled. 
#[derive(Default)] @@ -164,7 +165,6 @@ where } } -#[async_trait::async_trait] impl SimpleAsyncConnection for C where C: DerefMut + Send, @@ -176,28 +176,18 @@ where } } -#[async_trait::async_trait] -impl AsyncConnection for C +impl AsyncConnectionCore for C where C: DerefMut + Send, - C::Target: AsyncConnection, + C::Target: AsyncConnectionCore, { type ExecuteFuture<'conn, 'query> = - ::ExecuteFuture<'conn, 'query>; - type LoadFuture<'conn, 'query> = ::LoadFuture<'conn, 'query>; - type Stream<'conn, 'query> = ::Stream<'conn, 'query>; - type Row<'conn, 'query> = ::Row<'conn, 'query>; + ::ExecuteFuture<'conn, 'query>; + type LoadFuture<'conn, 'query> = ::LoadFuture<'conn, 'query>; + type Stream<'conn, 'query> = ::Stream<'conn, 'query>; + type Row<'conn, 'query> = ::Row<'conn, 'query>; - type Backend = ::Backend; - - type TransactionManager = - PoolTransactionManager<::TransactionManager>; - - async fn establish(_database_url: &str) -> diesel::ConnectionResult { - Err(diesel::result::ConnectionError::BadConnection( - String::from("Cannot directly establish a pooled connection"), - )) - } + type Backend = ::Backend; fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> where @@ -222,6 +212,21 @@ where let conn = self.deref_mut(); conn.execute_returning_count(source) } +} + +impl AsyncConnection for C +where + C: DerefMut + Send, + C::Target: AsyncConnection, +{ + type TransactionManager = + PoolTransactionManager<::TransactionManager>; + + async fn establish(_database_url: &str) -> diesel::ConnectionResult { + Err(diesel::result::ConnectionError::BadConnection( + String::from("Cannot directly establish a pooled connection"), + )) + } fn transaction_state( &mut self, @@ -241,13 +246,16 @@ where fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) { self.deref_mut().set_instrumentation(instrumentation); } + + fn set_prepared_statement_cache_size(&mut self, size: CacheSize) { + 
self.deref_mut().set_prepared_statement_cache_size(size); + } } #[doc(hidden)] #[allow(missing_debug_implementations)] pub struct PoolTransactionManager(std::marker::PhantomData); -#[async_trait::async_trait] impl TransactionManager for PoolTransactionManager where C: DerefMut + Send, @@ -279,18 +287,22 @@ where } } -#[async_trait::async_trait] impl UpdateAndFetchResults for Conn where Conn: DerefMut + Send, Changes: diesel::prelude::Identifiable + HasTable + Send, Conn::Target: UpdateAndFetchResults, { - async fn update_and_fetch(&mut self, changeset: Changes) -> QueryResult + fn update_and_fetch<'conn, 'changes>( + &'conn mut self, + changeset: Changes, + ) -> BoxFuture<'changes, QueryResult> where - Changes: 'async_trait, + Changes: 'changes, + 'conn: 'changes, + Self: 'changes, { - self.deref_mut().update_and_fetch(changeset).await + self.deref_mut().update_and_fetch(changeset) } } @@ -317,13 +329,15 @@ impl diesel::query_builder::Query for CheckConnectionQuery { impl diesel::query_dsl::RunQueryDsl for CheckConnectionQuery {} #[doc(hidden)] -#[async_trait::async_trait] pub trait PoolableConnection: AsyncConnection { /// Check if a connection is still valid /// /// The default implementation will perform a check based on the provided /// recycling method variant - async fn ping(&mut self, config: &RecyclingMethod) -> diesel::QueryResult<()> + fn ping( + &mut self, + config: &RecyclingMethod, + ) -> impl Future> + Send where for<'a> Self: 'a, diesel::dsl::select>: @@ -333,19 +347,21 @@ pub trait PoolableConnection: AsyncConnection { use crate::run_query_dsl::RunQueryDsl; use diesel::IntoSql; - match config { - RecyclingMethod::Fast => Ok(()), - RecyclingMethod::Verified => { - diesel::select(1_i32.into_sql::()) + async move { + match config { + RecyclingMethod::Fast => Ok(()), + RecyclingMethod::Verified => { + diesel::select(1_i32.into_sql::()) + .execute(self) + .await + .map(|_| ()) + } + RecyclingMethod::CustomQuery(query) => diesel::sql_query(query.as_ref()) 
.execute(self) .await - .map(|_| ()) + .map(|_| ()), + RecyclingMethod::CustomFunction(c) => c(self).await, } - RecyclingMethod::CustomQuery(query) => diesel::sql_query(query.as_ref()) - .execute(self) - .await - .map(|_| ()), - RecyclingMethod::CustomFunction(c) => c(self).await, } } diff --git a/src/run_query_dsl/mod.rs b/src/run_query_dsl/mod.rs index f3767ee..437d2a2 100644 --- a/src/run_query_dsl/mod.rs +++ b/src/run_query_dsl/mod.rs @@ -1,9 +1,12 @@ -use crate::AsyncConnection; +use crate::AsyncConnectionCore; use diesel::associations::HasTable; use diesel::query_builder::IntoUpdateTarget; use diesel::result::QueryResult; use diesel::AsChangeset; -use futures_util::{future, stream, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt}; +use futures_core::future::BoxFuture; +use futures_core::Stream; +use futures_util::{future, stream, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; +use std::future::Future; use std::pin::Pin; /// The traits used by `QueryDsl`. @@ -28,9 +31,9 @@ pub mod methods { /// to call `execute` from generic code. /// /// [`RunQueryDsl`]: super::RunQueryDsl - pub trait ExecuteDsl::Backend> + pub trait ExecuteDsl::Backend> where - Conn: AsyncConnection, + Conn: AsyncConnectionCore, DB: Backend, { /// Execute this command @@ -44,7 +47,7 @@ pub mod methods { impl ExecuteDsl for T where - Conn: AsyncConnection, + Conn: AsyncConnectionCore, DB: Backend, T: QueryFragment + QueryId + Send, { @@ -66,7 +69,7 @@ pub mod methods { /// to call `load` from generic code. 
/// /// [`RunQueryDsl`]: super::RunQueryDsl - pub trait LoadQuery<'query, Conn: AsyncConnection, U> { + pub trait LoadQuery<'query, Conn: AsyncConnectionCore, U> { /// The future returned by [`LoadQuery::internal_load`] type LoadFuture<'conn>: Future>> + Send where @@ -82,7 +85,7 @@ pub mod methods { impl<'query, Conn, DB, T, U, ST> LoadQuery<'query, Conn, U> for T where - Conn: AsyncConnection, + Conn: AsyncConnectionCore, U: Send, DB: Backend + 'static, T: AsQuery + Send + 'query, @@ -92,17 +95,21 @@ pub mod methods { DB: QueryMetadata, ST: 'static, { - type LoadFuture<'conn> = future::MapOk< + type LoadFuture<'conn> + = future::MapOk< Conn::LoadFuture<'conn, 'query>, fn(Conn::Stream<'conn, 'query>) -> Self::Stream<'conn>, - > where Conn: 'conn; + > + where + Conn: 'conn; - type Stream<'conn> = stream::Map< + type Stream<'conn> + = stream::Map< Conn::Stream<'conn, 'query>, - fn( - QueryResult>, - ) -> QueryResult, - >where Conn: 'conn; + fn(QueryResult>) -> QueryResult, + > + where + Conn: 'conn; fn internal_load(self, conn: &mut Conn) -> Self::LoadFuture<'_> { conn.load(self) @@ -220,7 +227,7 @@ pub trait RunQueryDsl: Sized { /// ``` fn execute<'conn, 'query>(self, conn: &'conn mut Conn) -> Conn::ExecuteFuture<'conn, 'query> where - Conn: AsyncConnection + Send, + Conn: AsyncConnectionCore + Send, Self: methods::ExecuteDsl + 'query, { methods::ExecuteDsl::execute(self, conn) @@ -336,7 +343,7 @@ pub trait RunQueryDsl: Sized { ) -> return_futures::LoadFuture<'conn, 'query, Self, Conn, U> where U: Send, - Conn: AsyncConnection, + Conn: AsyncConnectionCore, Self: methods::LoadQuery<'query, Conn, U> + 'query, { fn collect_result(stream: S) -> stream::TryCollect> @@ -392,7 +399,7 @@ pub trait RunQueryDsl: Sized { /// .await? 
/// .try_fold(Vec::new(), |mut acc, item| { /// acc.push(item); - /// futures_util::future::ready(Ok(acc)) + /// std::future::ready(Ok(acc)) /// }) /// .await?; /// assert_eq!(vec!["Sean", "Tess"], data); @@ -421,7 +428,7 @@ pub trait RunQueryDsl: Sized { /// .await? /// .try_fold(Vec::new(), |mut acc, item| { /// acc.push(item); - /// futures_util::future::ready(Ok(acc)) + /// std::future::ready(Ok(acc)) /// }) /// .await?; /// let expected_data = vec![ @@ -461,7 +468,7 @@ pub trait RunQueryDsl: Sized { /// .await? /// .try_fold(Vec::new(), |mut acc, item| { /// acc.push(item); - /// futures_util::future::ready(Ok(acc)) + /// std::future::ready(Ok(acc)) /// }) /// .await?; /// let expected_data = vec![ @@ -474,7 +481,7 @@ pub trait RunQueryDsl: Sized { /// ``` fn load_stream<'conn, 'query, U>(self, conn: &'conn mut Conn) -> Self::LoadFuture<'conn> where - Conn: AsyncConnection, + Conn: AsyncConnectionCore, U: 'conn, Self: methods::LoadQuery<'query, Conn, U> + 'query, { @@ -537,7 +544,7 @@ pub trait RunQueryDsl: Sized { ) -> return_futures::GetResult<'conn, 'query, Self, Conn, U> where U: Send + 'conn, - Conn: AsyncConnection, + Conn: AsyncConnectionCore, Self: methods::LoadQuery<'query, Conn, U> + 'query, { #[allow(clippy::type_complexity)] @@ -577,7 +584,7 @@ pub trait RunQueryDsl: Sized { ) -> return_futures::LoadFuture<'conn, 'query, Self, Conn, U> where U: Send, - Conn: AsyncConnection, + Conn: AsyncConnectionCore, Self: methods::LoadQuery<'query, Conn, U> + 'query, { self.load(conn) @@ -633,7 +640,7 @@ pub trait RunQueryDsl: Sized { ) -> return_futures::GetResult<'conn, 'query, diesel::dsl::Limit, Conn, U> where U: Send + 'conn, - Conn: AsyncConnection, + Conn: AsyncConnectionCore, Self: diesel::query_dsl::methods::LimitDsl, diesel::dsl::Limit: methods::LoadQuery<'query, Conn, U> + Send + 'query, { @@ -695,15 +702,20 @@ impl RunQueryDsl for T {} /// # Ok(()) /// # } /// ``` -#[async_trait::async_trait] pub trait SaveChangesDsl { /// See the trait 
documentation - async fn save_changes(self, connection: &mut Conn) -> QueryResult + fn save_changes<'life0, 'async_trait, T>( + self, + connection: &'life0 mut Conn, + ) -> impl Future> + Send + 'async_trait where Self: Sized + diesel::prelude::Identifiable, Conn: UpdateAndFetchResults, + T: 'async_trait, + 'life0: 'async_trait, + Self: ::core::marker::Send + 'async_trait, { - connection.update_and_fetch(self).await + connection.update_and_fetch(self) } } @@ -722,58 +734,69 @@ impl SaveChangesDsl for T where /// For implementing this trait for a custom backend: /// * The `Changes` generic parameter represents the changeset that should be stored /// * The `Output` generic parameter represents the type of the response. -#[async_trait::async_trait] -pub trait UpdateAndFetchResults: AsyncConnection +pub trait UpdateAndFetchResults: AsyncConnectionCore where Changes: diesel::prelude::Identifiable + HasTable, { /// See the traits documentation. - async fn update_and_fetch(&mut self, changeset: Changes) -> QueryResult + fn update_and_fetch<'conn, 'changes>( + &'conn mut self, + changeset: Changes, + ) -> BoxFuture<'changes, QueryResult> + // cannot use impl future due to rustc bugs + // https://github.com/rust-lang/rust/issues/135619 + //impl Future> + Send + 'changes where - Changes: 'async_trait; + Changes: 'changes, + 'conn: 'changes, + Self: 'changes; } #[cfg(feature = "mysql")] -#[async_trait::async_trait] -impl<'b, Changes, Output> UpdateAndFetchResults for crate::AsyncMysqlConnection +impl<'b, Changes, Output, Tab, V> UpdateAndFetchResults + for crate::AsyncMysqlConnection where - Output: Send, - Changes: Copy + diesel::Identifiable + Send, - Changes: AsChangeset::Table> + IntoUpdateTarget, - Changes::Table: diesel::query_dsl::methods::FindDsl + Send, - Changes::WhereClause: Send, - Changes::Changeset: Send, - Changes::Id: Send, - diesel::dsl::Update: methods::ExecuteDsl, + Output: Send + 'static, + Changes: + Copy + AsChangeset + Send + 
diesel::associations::Identifiable, + Tab: diesel::Table + diesel::query_dsl::methods::FindDsl + 'b, + diesel::dsl::Find: IntoUpdateTarget
, + diesel::query_builder::UpdateStatement: + diesel::query_builder::AsQuery, + diesel::dsl::Update: methods::ExecuteDsl, + V: Send + 'b, + Changes::Changeset: Send + 'b, + Changes::Id: 'b, + Tab::FromClause: Send, diesel::dsl::Find: - methods::LoadQuery<'b, crate::AsyncMysqlConnection, Output> + Send + 'b, - ::AllColumns: diesel::expression::ValidGrouping<()>, - <::AllColumns as diesel::expression::ValidGrouping<()>>::IsAggregate: diesel::expression::MixedAggregates< - diesel::expression::is_aggregate::No, - Output = diesel::expression::is_aggregate::No, - >, - ::FromClause: Send, + methods::LoadQuery<'b, crate::AsyncMysqlConnection, Output> + Send, { - async fn update_and_fetch(&mut self, changeset: Changes) -> QueryResult + fn update_and_fetch<'conn, 'changes>( + &'conn mut self, + changeset: Changes, + ) -> BoxFuture<'changes, QueryResult> where - Changes: 'async_trait, + Changes: 'changes, + Changes::Changeset: 'changes, + 'conn: 'changes, + Self: 'changes, { - use diesel::query_dsl::methods::FindDsl; - - diesel::update(changeset) - .set(changeset) - .execute(self) - .await?; - Changes::table().find(changeset.id()).get_result(self).await + async move { + diesel::update(changeset) + .set(changeset) + .execute(self) + .await?; + Changes::table().find(changeset.id()).get_result(self).await + } + .boxed() } } #[cfg(feature = "postgres")] -#[async_trait::async_trait] impl<'b, Changes, Output, Tab, V> UpdateAndFetchResults for crate::AsyncPgConnection where - Output: Send, + Output: Send + 'static, Changes: Copy + AsChangeset + Send + diesel::associations::Identifiable
, Tab: diesel::Table + diesel::query_dsl::methods::FindDsl + 'b, @@ -785,14 +808,22 @@ where Changes::Changeset: Send + 'b, Tab::FromClause: Send, { - async fn update_and_fetch(&mut self, changeset: Changes) -> QueryResult + fn update_and_fetch<'conn, 'changes>( + &'conn mut self, + changeset: Changes, + ) -> BoxFuture<'changes, QueryResult> where - Changes: 'async_trait, - Changes::Changeset: 'async_trait, + Changes: 'changes, + Changes::Changeset: 'changes, + 'conn: 'changes, + Self: 'changes, { - diesel::update(changeset) - .set(changeset) - .get_result(self) - .await + async move { + diesel::update(changeset) + .set(changeset) + .get_result(self) + .await + } + .boxed() } } diff --git a/src/stmt_cache.rs b/src/stmt_cache.rs index 9d6b9af..c2270b8 100644 --- a/src/stmt_cache.rs +++ b/src/stmt_cache.rs @@ -1,91 +1,59 @@ -use std::collections::HashMap; -use std::hash::Hash; - -use diesel::backend::Backend; -use diesel::connection::statement_cache::{MaybeCached, PrepareForCache, StatementCacheKey}; -use diesel::connection::Instrumentation; -use diesel::connection::InstrumentationEvent; +use diesel::connection::statement_cache::{MaybeCached, StatementCallbackReturnType}; use diesel::QueryResult; -use futures_util::{future, FutureExt}; +use futures_core::future::BoxFuture; +use futures_util::future::Either; +use futures_util::{FutureExt, TryFutureExt}; +use std::future::{self, Future}; -#[derive(Default)] -pub struct StmtCache { - cache: HashMap, S>, -} +pub(crate) struct CallbackHelper(pub(crate) F); -type PrepareFuture<'a, F, S> = future::Either< - future::Ready, F)>>, - future::BoxFuture<'a, QueryResult<(MaybeCached<'a, S>, F)>>, +type PrepareFuture<'a, C, S> = Either< + future::Ready, C)>>, + BoxFuture<'a, QueryResult<(MaybeCached<'a, S>, C)>>, >; -#[async_trait::async_trait] -pub trait PrepareCallback: Sized { - async fn prepare( - self, - sql: &str, - metadata: &[M], - is_for_cache: PrepareForCache, - ) -> QueryResult<(S, Self)>; -} +impl 
StatementCallbackReturnType for CallbackHelper +where + F: Future> + Send, + S: 'static, +{ + type Return<'a> = PrepareFuture<'a, C, S>; -impl StmtCache { - pub fn new() -> Self { - Self { - cache: HashMap::new(), - } + fn from_error<'a>(e: diesel::result::Error) -> Self::Return<'a> { + Either::Left(future::ready(Err(e))) } - pub fn cached_prepared_statement<'a, F>( - &'a mut self, - cache_key: StatementCacheKey, - sql: String, - is_query_safe_to_cache: bool, - metadata: &[DB::TypeMetadata], - prepare_fn: F, - instrumentation: &std::sync::Mutex>>, - ) -> PrepareFuture<'a, F, S> + fn map_to_no_cache<'a>(self) -> Self::Return<'a> where - S: Send, - DB::QueryBuilder: Default, - DB::TypeMetadata: Clone + Send + Sync, - F: PrepareCallback + Send + 'a, - StatementCacheKey: Hash + Eq, + Self: 'a, { - use std::collections::hash_map::Entry::{Occupied, Vacant}; - - if !is_query_safe_to_cache { - let metadata = metadata.to_vec(); - let f = async move { - let stmt = prepare_fn - .prepare(&sql, &metadata, PrepareForCache::No) - .await?; - Ok((MaybeCached::CannotCache(stmt.0), stmt.1)) - } - .boxed(); - return future::Either::Right(f); - } + Either::Right( + self.0 + .map_ok(|(stmt, conn)| (MaybeCached::CannotCache(stmt), conn)) + .boxed(), + ) + } - match self.cache.entry(cache_key) { - Occupied(entry) => future::Either::Left(future::ready(Ok(( - MaybeCached::Cached(entry.into_mut()), - prepare_fn, - )))), - Vacant(entry) => { - let metadata = metadata.to_vec(); - instrumentation - .lock() - .unwrap_or_else(|p| p.into_inner()) - .on_connection_event(InstrumentationEvent::cache_query(&sql)); - let f = async move { - let statement = prepare_fn - .prepare(&sql, &metadata, PrepareForCache::Yes) - .await?; + fn map_to_cache(stmt: &mut S, conn: C) -> Self::Return<'_> { + Either::Left(future::ready(Ok((MaybeCached::Cached(stmt), conn)))) + } - Ok((MaybeCached::Cached(entry.insert(statement.0)), statement.1)) - } - .boxed(); - future::Either::Right(f) - } - } + fn register_cache<'a>( + 
self, + callback: impl FnOnce(S) -> &'a mut S + Send + 'a, + ) -> Self::Return<'a> + where + Self: 'a, + { + Either::Right( + self.0 + .map_ok(|(stmt, conn)| (MaybeCached::Cached(callback(stmt)), conn)) + .boxed(), + ) } } + +pub(crate) struct QueryFragmentHelper { + pub(crate) sql: String, + pub(crate) safe_to_cache: bool, +} diff --git a/src/sync_connection_wrapper/mod.rs b/src/sync_connection_wrapper/mod.rs index 76a06da..cbb8436 100644 --- a/src/sync_connection_wrapper/mod.rs +++ b/src/sync_connection_wrapper/mod.rs @@ -6,33 +6,30 @@ //! //! * using a sync Connection implementation in async context //! * using the same code base for async crates needing multiple backends - -use crate::{AsyncConnection, SimpleAsyncConnection, TransactionManager}; -use diesel::backend::{Backend, DieselReserveSpecialization}; -use diesel::connection::Instrumentation; -use diesel::connection::{ - Connection, LoadConnection, TransactionManagerStatus, WithMetadataLookup, -}; -use diesel::query_builder::{ - AsQuery, CollectedQuery, MoveableBindCollector, QueryBuilder, QueryFragment, QueryId, -}; -use diesel::row::IntoOwnedRow; -use diesel::{ConnectionResult, QueryResult}; -use futures_util::future::BoxFuture; -use futures_util::stream::BoxStream; -use futures_util::{FutureExt, StreamExt, TryFutureExt}; -use std::marker::PhantomData; -use std::sync::{Arc, Mutex}; -use tokio::task::JoinError; +use futures_core::future::BoxFuture; +use std::error::Error; #[cfg(feature = "sqlite")] mod sqlite; -fn from_tokio_join_error(join_error: JoinError) -> diesel::result::Error { - diesel::result::Error::DatabaseError( - diesel::result::DatabaseErrorKind::UnableToSendCommand, - Box::new(join_error.to_string()), - ) +/// This is a helper trait that allows to customize the +/// spawning blocking tasks as part of the +/// [`SyncConnectionWrapper`] type. By default a +/// tokio runtime and its spawn_blocking function is used. 
+pub trait SpawnBlocking { + /// This function should allow to execute a + /// given blocking task without blocking the caller + /// to get the result + fn spawn_blocking<'a, R>( + &mut self, + task: impl FnOnce() -> R + Send + 'static, + ) -> BoxFuture<'a, Result>> + where + R: Send + 'static; + + /// This function should be used to construct + /// a new runtime instance + fn get_runtime() -> Self; } /// A wrapper of a [`diesel::connection::Connection`] usable in async context. @@ -73,220 +70,135 @@ fn from_tokio_join_error(join_error: JoinError) -> diesel::result::Error { /// # some_async_fn().await; /// # } /// ``` -pub struct SyncConnectionWrapper { - inner: Arc>, -} - -#[async_trait::async_trait] -impl SimpleAsyncConnection for SyncConnectionWrapper -where - C: diesel::connection::Connection + 'static, -{ - async fn batch_execute(&mut self, query: &str) -> QueryResult<()> { - let query = query.to_string(); - self.spawn_blocking(move |inner| inner.batch_execute(query.as_str())) - .await - } -} +#[cfg(feature = "tokio")] +pub type SyncConnectionWrapper = + self::implementation::SyncConnectionWrapper; -#[async_trait::async_trait] -impl AsyncConnection for SyncConnectionWrapper -where - // Backend bounds - ::Backend: std::default::Default + DieselReserveSpecialization, - ::QueryBuilder: std::default::Default, - // Connection bounds - C: Connection + LoadConnection + WithMetadataLookup + 'static, - ::TransactionManager: Send, - // BindCollector bounds - MD: Send + 'static, - for<'a> ::BindCollector<'a>: - MoveableBindCollector + std::default::Default, - // Row bounds - O: 'static + Send + for<'conn> diesel::row::Row<'conn, C::Backend>, - for<'conn, 'query> ::Row<'conn, 'query>: - IntoOwnedRow<'conn, ::Backend, OwnedRow = O>, -{ - type LoadFuture<'conn, 'query> = BoxFuture<'query, QueryResult>>; - type ExecuteFuture<'conn, 'query> = BoxFuture<'query, QueryResult>; - type Stream<'conn, 'query> = BoxStream<'static, QueryResult>>; - type Row<'conn, 'query> = O; - type 
Backend = ::Backend; - type TransactionManager = SyncTransactionManagerWrapper<::TransactionManager>; - - async fn establish(database_url: &str) -> ConnectionResult { - let database_url = database_url.to_string(); - tokio::task::spawn_blocking(move || C::establish(&database_url)) - .await - .unwrap_or_else(|e| Err(diesel::ConnectionError::BadConnection(e.to_string()))) - .map(|c| SyncConnectionWrapper::new(c)) - } + }

/// A wrapper of a [`diesel::connection::Connection`] usable in async context.
+///
+/// It implements AsyncConnection if [`diesel::connection::Connection`] fulfills requirements:
+/// * it's a [`diesel::connection::LoadConnection`]
+/// * its [`diesel::connection::Connection::Backend`] has a [`diesel::query_builder::BindCollector`] implementing [`diesel::query_builder::MoveableBindCollector`]
+/// * its [`diesel::connection::LoadConnection::Row`] implements [`diesel::row::IntoOwnedRow`]
+///
+/// Internally this wrapper type will use `spawn_blocking` on given type implementing [`SpawnBlocking`] trait
+/// to execute the request on the inner connection. 
+#[cfg(not(feature = "tokio"))] +pub use self::implementation::SyncConnectionWrapper; - fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> - where - T: AsQuery + 'query, - T::Query: QueryFragment + QueryId + 'query, - { - self.execute_with_prepared_query(source.as_query(), |conn, query| { - use diesel::row::IntoOwnedRow; - let mut cache = <<::Row<'_, '_> as IntoOwnedRow< - ::Backend, - >>::Cache as Default>::default(); - let cursor = conn.load(&query)?; - - let size_hint = cursor.size_hint(); - let mut out = Vec::with_capacity(size_hint.1.unwrap_or(size_hint.0)); - // we use an explicit loop here to easily propagate possible errors - // as early as possible - for row in cursor { - out.push(Ok(IntoOwnedRow::into_owned(row?, &mut cache))); - } +pub use self::implementation::SyncTransactionManagerWrapper; - Ok(out) - }) - .map_ok(|rows| futures_util::stream::iter(rows).boxed()) - .boxed() - } +mod implementation { + use crate::{AsyncConnection, AsyncConnectionCore, SimpleAsyncConnection, TransactionManager}; + use diesel::backend::{Backend, DieselReserveSpecialization}; + use diesel::connection::{CacheSize, Instrumentation}; + use diesel::connection::{ + Connection, LoadConnection, TransactionManagerStatus, WithMetadataLookup, + }; + use diesel::query_builder::{ + AsQuery, CollectedQuery, MoveableBindCollector, QueryBuilder, QueryFragment, QueryId, + }; + use diesel::row::IntoOwnedRow; + use diesel::{ConnectionResult, QueryResult}; + use futures_core::stream::BoxStream; + use futures_util::{FutureExt, StreamExt, TryFutureExt}; + use std::marker::PhantomData; + use std::sync::{Arc, Mutex}; - fn execute_returning_count<'query, T>(&mut self, source: T) -> Self::ExecuteFuture<'_, 'query> - where - T: QueryFragment + QueryId, - { - self.execute_with_prepared_query(source, |conn, query| conn.execute_returning_count(&query)) - } + use super::*; - fn transaction_state( - &mut self, - ) -> &mut >::TransactionStateData { - 
self.exclusive_connection().transaction_state() + fn from_spawn_blocking_error( + error: Box, + ) -> diesel::result::Error { + diesel::result::Error::DatabaseError( + diesel::result::DatabaseErrorKind::UnableToSendCommand, + Box::new(error.to_string()), + ) } - fn instrumentation(&mut self) -> &mut dyn Instrumentation { - // there should be no other pending future when this is called - // that means there is only one instance of this arc and - // we can simply access the inner data - if let Some(inner) = Arc::get_mut(&mut self.inner) { - inner - .get_mut() - .unwrap_or_else(|p| p.into_inner()) - .instrumentation() - } else { - panic!("Cannot access shared instrumentation") - } + pub struct SyncConnectionWrapper { + inner: Arc>, + runtime: S, } - fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) { - // there should be no other pending future when this is called - // that means there is only one instance of this arc and - // we can simply access the inner data - if let Some(inner) = Arc::get_mut(&mut self.inner) { - inner - .get_mut() - .unwrap_or_else(|p| p.into_inner()) - .set_instrumentation(instrumentation) - } else { - panic!("Cannot access shared instrumentation") + impl SimpleAsyncConnection for SyncConnectionWrapper + where + C: diesel::connection::Connection + 'static, + S: SpawnBlocking + Send, + { + async fn batch_execute(&mut self, query: &str) -> QueryResult<()> { + let query = query.to_string(); + self.spawn_blocking(move |inner| inner.batch_execute(query.as_str())) + .await } } -} -/// A wrapper of a diesel transaction manager usable in async context. 
-pub struct SyncTransactionManagerWrapper(PhantomData); - -#[async_trait::async_trait] -impl TransactionManager> for SyncTransactionManagerWrapper -where - SyncConnectionWrapper: AsyncConnection, - C: Connection + 'static, - T: diesel::connection::TransactionManager + Send, -{ - type TransactionStateData = T::TransactionStateData; - - async fn begin_transaction(conn: &mut SyncConnectionWrapper) -> QueryResult<()> { - conn.spawn_blocking(move |inner| T::begin_transaction(inner)) - .await - } + impl AsyncConnectionCore for SyncConnectionWrapper + where + // Backend bounds + ::Backend: std::default::Default + DieselReserveSpecialization, + ::QueryBuilder: std::default::Default, + // Connection bounds + C: Connection + LoadConnection + WithMetadataLookup + 'static, + ::TransactionManager: Send, + // BindCollector bounds + MD: Send + 'static, + for<'a> ::BindCollector<'a>: + MoveableBindCollector + std::default::Default, + // Row bounds + O: 'static + Send + for<'conn> diesel::row::Row<'conn, C::Backend>, + for<'conn, 'query> ::Row<'conn, 'query>: + IntoOwnedRow<'conn, ::Backend, OwnedRow = O>, + // SpawnBlocking bounds + S: SpawnBlocking + Send, + { + type LoadFuture<'conn, 'query> = + BoxFuture<'query, QueryResult>>; + type ExecuteFuture<'conn, 'query> = BoxFuture<'query, QueryResult>; + type Stream<'conn, 'query> = BoxStream<'static, QueryResult>>; + type Row<'conn, 'query> = O; + type Backend = ::Backend; - async fn commit_transaction(conn: &mut SyncConnectionWrapper) -> QueryResult<()> { - conn.spawn_blocking(move |inner| T::commit_transaction(inner)) - .await - } + fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> + where + T: AsQuery + 'query, + T::Query: QueryFragment + QueryId + 'query, + { + self.execute_with_prepared_query(source.as_query(), |conn, query| { + use diesel::row::IntoOwnedRow; + let mut cache = <<::Row<'_, '_> as IntoOwnedRow< + ::Backend, + >>::Cache as Default>::default(); + let cursor = 
conn.load(&query)?; - async fn rollback_transaction(conn: &mut SyncConnectionWrapper) -> QueryResult<()> { - conn.spawn_blocking(move |inner| T::rollback_transaction(inner)) - .await - } + let size_hint = cursor.size_hint(); + let mut out = Vec::with_capacity(size_hint.1.unwrap_or(size_hint.0)); + // we use an explicit loop here to easily propagate possible errors + // as early as possible + for row in cursor { + out.push(Ok(IntoOwnedRow::into_owned(row?, &mut cache))); + } - fn transaction_manager_status_mut( - conn: &mut SyncConnectionWrapper, - ) -> &mut TransactionManagerStatus { - T::transaction_manager_status_mut(conn.exclusive_connection()) - } -} - -impl SyncConnectionWrapper { - /// Builds a wrapper with this underlying sync connection - pub fn new(connection: C) -> Self - where - C: Connection, - { - SyncConnectionWrapper { - inner: Arc::new(Mutex::new(connection)), + Ok(out) + }) + .map_ok(|rows| futures_util::stream::iter(rows).boxed()) + .boxed() } - } - /// Run a operation directly with the inner connection - /// - /// This function is usful to register custom functions - /// and collection for Sqlite for example - /// - /// # Example - /// - /// ```rust - /// # include!("../doctest_setup.rs"); - /// # #[tokio::main] - /// # async fn main() { - /// # run_test().await.unwrap(); - /// # } - /// # - /// # async fn run_test() -> QueryResult<()> { - /// # let mut conn = establish_connection().await; - /// conn.spawn_blocking(|conn| { - /// // sqlite.rs sqlite NOCASE only works for ASCII characters, - /// // this collation allows handling UTF-8 (barring locale differences) - /// conn.register_collation("RUSTNOCASE", |rhs, lhs| { - /// rhs.to_lowercase().cmp(&lhs.to_lowercase()) - /// }) - /// }).await - /// - /// # } - /// ``` - pub fn spawn_blocking<'a, R>( - &mut self, - task: impl FnOnce(&mut C) -> QueryResult + Send + 'static, - ) -> BoxFuture<'a, QueryResult> - where - C: Connection + 'static, - R: Send + 'static, - { - let inner = self.inner.clone(); 
- tokio::task::spawn_blocking(move || { - let mut inner = inner.lock().unwrap_or_else(|poison| { - // try to be resilient by providing the guard - inner.clear_poison(); - poison.into_inner() - }); - task(&mut inner) - }) - .unwrap_or_else(|err| QueryResult::Err(from_tokio_join_error(err))) - .boxed() + fn execute_returning_count<'query, T>( + &mut self, + source: T, + ) -> Self::ExecuteFuture<'_, 'query> + where + T: QueryFragment + QueryId, + { + self.execute_with_prepared_query(source, |conn, query| { + conn.execute_returning_count(&query) + }) + } } - fn execute_with_prepared_query<'a, MD, Q, R>( - &mut self, - query: Q, - callback: impl FnOnce(&mut C, &CollectedQuery) -> QueryResult + Send + 'static, - ) -> BoxFuture<'a, QueryResult> + impl AsyncConnection for SyncConnectionWrapper where // Backend bounds ::Backend: std::default::Default + DieselReserveSpecialization, @@ -296,75 +208,310 @@ impl SyncConnectionWrapper { ::TransactionManager: Send, // BindCollector bounds MD: Send + 'static, - for<'b> ::BindCollector<'b>: + for<'a> ::BindCollector<'a>: MoveableBindCollector + std::default::Default, - // Arguments/Return bounds - Q: QueryFragment + QueryId, - R: Send + 'static, + // Row bounds + O: 'static + Send + for<'conn> diesel::row::Row<'conn, C::Backend>, + for<'conn, 'query> ::Row<'conn, 'query>: + IntoOwnedRow<'conn, ::Backend, OwnedRow = O>, + // SpawnBlocking bounds + S: SpawnBlocking + Send, { - let backend = C::Backend::default(); - - let (collect_bind_result, collector_data) = { - let exclusive = self.inner.clone(); - let mut inner = exclusive.lock().unwrap_or_else(|poison| { - // try to be resilient by providing the guard - exclusive.clear_poison(); - poison.into_inner() - }); - let mut bind_collector = - <::BindCollector<'_> as Default>::default(); - let metadata_lookup = inner.metadata_lookup(); - let result = query.collect_binds(&mut bind_collector, metadata_lookup, &backend); - let collector_data = bind_collector.moveable(); - - (result, 
collector_data) - }; - - let mut query_builder = <::QueryBuilder as Default>::default(); - let sql = query - .to_sql(&mut query_builder, &backend) - .map(|_| query_builder.finish()); - let is_safe_to_cache_prepared = query.is_safe_to_cache_prepared(&backend); - - self.spawn_blocking(|inner| { - collect_bind_result?; - let query = CollectedQuery::new(sql?, is_safe_to_cache_prepared?, collector_data); - callback(inner, &query) - }) + type TransactionManager = + SyncTransactionManagerWrapper<::TransactionManager>; + + async fn establish(database_url: &str) -> ConnectionResult { + let database_url = database_url.to_string(); + let mut runtime = S::get_runtime(); + + runtime + .spawn_blocking(move || C::establish(&database_url)) + .await + .unwrap_or_else(|e| Err(diesel::ConnectionError::BadConnection(e.to_string()))) + .map(move |c| SyncConnectionWrapper::with_runtime(c, runtime)) + } + + fn transaction_state( + &mut self, + ) -> &mut >::TransactionStateData + { + self.exclusive_connection().transaction_state() + } + + fn instrumentation(&mut self) -> &mut dyn Instrumentation { + // there should be no other pending future when this is called + // that means there is only one instance of this arc and + // we can simply access the inner data + if let Some(inner) = Arc::get_mut(&mut self.inner) { + inner + .get_mut() + .unwrap_or_else(|p| p.into_inner()) + .instrumentation() + } else { + panic!("Cannot access shared instrumentation") + } + } + + fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) { + // there should be no other pending future when this is called + // that means there is only one instance of this arc and + // we can simply access the inner data + if let Some(inner) = Arc::get_mut(&mut self.inner) { + inner + .get_mut() + .unwrap_or_else(|p| p.into_inner()) + .set_instrumentation(instrumentation) + } else { + panic!("Cannot access shared instrumentation") + } + } + + fn set_prepared_statement_cache_size(&mut self, size: CacheSize) { + 
// there should be no other pending future when this is called + // that means there is only one instance of this arc and + // we can simply access the inner data + if let Some(inner) = Arc::get_mut(&mut self.inner) { + inner + .get_mut() + .unwrap_or_else(|p| p.into_inner()) + .set_prepared_statement_cache_size(size) + } else { + panic!("Cannot access shared cache") + } + } } - /// Gets an exclusive access to the underlying diesel Connection - /// - /// It panics in case of shared access. - /// This is typically used only used during transaction. - pub(self) fn exclusive_connection(&mut self) -> &mut C + /// A wrapper of a diesel transaction manager usable in async context. + pub struct SyncTransactionManagerWrapper(PhantomData); + + impl TransactionManager> for SyncTransactionManagerWrapper where - C: Connection, + SyncConnectionWrapper: AsyncConnection, + C: Connection + 'static, + S: SpawnBlocking, + T: diesel::connection::TransactionManager + Send, { - // there should be no other pending future when this is called - // that means there is only one instance of this Arc and - // we can simply access the inner data - if let Some(conn_mutex) = Arc::get_mut(&mut self.inner) { - conn_mutex - .get_mut() - .expect("Mutex is poisoned, a thread must have panicked holding it.") - } else { - panic!("Cannot access shared transaction state") + type TransactionStateData = T::TransactionStateData; + + async fn begin_transaction(conn: &mut SyncConnectionWrapper) -> QueryResult<()> { + conn.spawn_blocking(move |inner| T::begin_transaction(inner)) + .await + } + + async fn commit_transaction(conn: &mut SyncConnectionWrapper) -> QueryResult<()> { + conn.spawn_blocking(move |inner| T::commit_transaction(inner)) + .await + } + + async fn rollback_transaction(conn: &mut SyncConnectionWrapper) -> QueryResult<()> { + conn.spawn_blocking(move |inner| T::rollback_transaction(inner)) + .await + } + + fn transaction_manager_status_mut( + conn: &mut SyncConnectionWrapper, + ) -> &mut 
TransactionManagerStatus { + T::transaction_manager_status_mut(conn.exclusive_connection()) + } + } + + impl SyncConnectionWrapper { + /// Builds a wrapper with this underlying sync connection + pub fn new(connection: C) -> Self + where + C: Connection, + S: SpawnBlocking, + { + SyncConnectionWrapper { + inner: Arc::new(Mutex::new(connection)), + runtime: S::get_runtime(), + } + } + + /// Builds a wrapper with this underlying sync connection + /// and runtime for spawning blocking tasks + pub fn with_runtime(connection: C, runtime: S) -> Self + where + C: Connection, + S: SpawnBlocking, + { + SyncConnectionWrapper { + inner: Arc::new(Mutex::new(connection)), + runtime, + } + } + + /// Run a operation directly with the inner connection + /// + /// This function is usful to register custom functions + /// and collection for Sqlite for example + /// + /// # Example + /// + /// ```rust + /// # include!("../doctest_setup.rs"); + /// # #[tokio::main] + /// # async fn main() { + /// # run_test().await.unwrap(); + /// # } + /// # + /// # async fn run_test() -> QueryResult<()> { + /// # let mut conn = establish_connection().await; + /// conn.spawn_blocking(|conn| { + /// // sqlite.rs sqlite NOCASE only works for ASCII characters, + /// // this collation allows handling UTF-8 (barring locale differences) + /// conn.register_collation("RUSTNOCASE", |rhs, lhs| { + /// rhs.to_lowercase().cmp(&lhs.to_lowercase()) + /// }) + /// }).await + /// + /// # } + /// ``` + pub fn spawn_blocking<'a, R>( + &mut self, + task: impl FnOnce(&mut C) -> QueryResult + Send + 'static, + ) -> BoxFuture<'a, QueryResult> + where + C: Connection + 'static, + R: Send + 'static, + S: SpawnBlocking, + { + let inner = self.inner.clone(); + self.runtime + .spawn_blocking(move || { + let mut inner = inner.lock().unwrap_or_else(|poison| { + // try to be resilient by providing the guard + inner.clear_poison(); + poison.into_inner() + }); + task(&mut inner) + }) + .unwrap_or_else(|err| 
QueryResult::Err(from_spawn_blocking_error(err))) + .boxed() + } + + fn execute_with_prepared_query<'a, MD, Q, R>( + &mut self, + query: Q, + callback: impl FnOnce(&mut C, &CollectedQuery) -> QueryResult + Send + 'static, + ) -> BoxFuture<'a, QueryResult> + where + // Backend bounds + ::Backend: std::default::Default + DieselReserveSpecialization, + ::QueryBuilder: std::default::Default, + // Connection bounds + C: Connection + LoadConnection + WithMetadataLookup + 'static, + ::TransactionManager: Send, + // BindCollector bounds + MD: Send + 'static, + for<'b> ::BindCollector<'b>: + MoveableBindCollector + std::default::Default, + // Arguments/Return bounds + Q: QueryFragment + QueryId, + R: Send + 'static, + // SpawnBlocking bounds + S: SpawnBlocking, + { + let backend = C::Backend::default(); + + let (collect_bind_result, collector_data) = { + let exclusive = self.inner.clone(); + let mut inner = exclusive.lock().unwrap_or_else(|poison| { + // try to be resilient by providing the guard + exclusive.clear_poison(); + poison.into_inner() + }); + let mut bind_collector = + <::BindCollector<'_> as Default>::default(); + let metadata_lookup = inner.metadata_lookup(); + let result = query.collect_binds(&mut bind_collector, metadata_lookup, &backend); + let collector_data = bind_collector.moveable(); + + (result, collector_data) + }; + + let mut query_builder = <::QueryBuilder as Default>::default(); + let sql = query + .to_sql(&mut query_builder, &backend) + .map(|_| query_builder.finish()); + let is_safe_to_cache_prepared = query.is_safe_to_cache_prepared(&backend); + + self.spawn_blocking(|inner| { + collect_bind_result?; + let query = CollectedQuery::new(sql?, is_safe_to_cache_prepared?, collector_data); + callback(inner, &query) + }) + } + + /// Gets an exclusive access to the underlying diesel Connection + /// + /// It panics in case of shared access. + /// This is typically used only used during transaction. 
+ pub(self) fn exclusive_connection(&mut self) -> &mut C + where + C: Connection, + { + // there should be no other pending future when this is called + // that means there is only one instance of this Arc and + // we can simply access the inner data + if let Some(conn_mutex) = Arc::get_mut(&mut self.inner) { + conn_mutex + .get_mut() + .expect("Mutex is poisoned, a thread must have panicked holding it.") + } else { + panic!("Cannot access shared transaction state") + } } } -} -#[cfg(any( - feature = "deadpool", - feature = "bb8", - feature = "mobc", - feature = "r2d2" -))] -impl crate::pooled_connection::PoolableConnection for SyncConnectionWrapper -where - Self: AsyncConnection, -{ - fn is_broken(&mut self) -> bool { - Self::TransactionManager::is_broken_transaction_manager(self) + #[cfg(any( + feature = "deadpool", + feature = "bb8", + feature = "mobc", + feature = "r2d2" + ))] + impl crate::pooled_connection::PoolableConnection for SyncConnectionWrapper + where + Self: AsyncConnection, + { + fn is_broken(&mut self) -> bool { + Self::TransactionManager::is_broken_transaction_manager(self) + } + } + + #[cfg(feature = "tokio")] + pub enum Tokio { + Handle(tokio::runtime::Handle), + Runtime(tokio::runtime::Runtime), + } + + #[cfg(feature = "tokio")] + impl SpawnBlocking for Tokio { + fn spawn_blocking<'a, R>( + &mut self, + task: impl FnOnce() -> R + Send + 'static, + ) -> BoxFuture<'a, Result>> + where + R: Send + 'static, + { + let fut = match self { + Tokio::Handle(handle) => handle.spawn_blocking(task), + Tokio::Runtime(runtime) => runtime.spawn_blocking(task), + }; + + fut.map_err(|err| Box::from(err)).boxed() + } + + fn get_runtime() -> Self { + if let Ok(handle) = tokio::runtime::Handle::try_current() { + Tokio::Handle(handle) + } else { + let runtime = tokio::runtime::Builder::new_current_thread() + .build() + .unwrap(); + + Tokio::Runtime(runtime) + } + } } } diff --git a/src/transaction_manager.rs b/src/transaction_manager.rs index 57383e8..22115ac 100644 
--- a/src/transaction_manager.rs +++ b/src/transaction_manager.rs @@ -7,6 +7,7 @@ use diesel::result::Error; use diesel::QueryResult; use scoped_futures::ScopedBoxFuture; use std::borrow::Cow; +use std::future::Future; use std::num::NonZeroU32; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -18,7 +19,6 @@ use crate::AsyncConnection; /// /// You will not need to interact with this trait, unless you are writing an /// implementation of [`AsyncConnection`]. -#[async_trait::async_trait] pub trait TransactionManager: Send { /// Data stored as part of the connection implementation /// to track the current transaction state of a connection @@ -29,21 +29,21 @@ pub trait TransactionManager: Send { /// If the transaction depth is greater than 0, /// this should create a savepoint instead. /// This function is expected to increment the transaction depth by 1. - async fn begin_transaction(conn: &mut Conn) -> QueryResult<()>; + fn begin_transaction(conn: &mut Conn) -> impl Future> + Send; /// Rollback the inner-most transaction or savepoint /// /// If the transaction depth is greater than 1, /// this should rollback to the most recent savepoint. /// This function is expected to decrement the transaction depth by 1. - async fn rollback_transaction(conn: &mut Conn) -> QueryResult<()>; + fn rollback_transaction(conn: &mut Conn) -> impl Future> + Send; /// Commit the inner-most transaction or savepoint /// /// If the transaction depth is greater than 1, /// this should release the most recent savepoint. /// This function is expected to decrement the transaction depth by 1. 
- async fn commit_transaction(conn: &mut Conn) -> QueryResult<()>; + fn commit_transaction(conn: &mut Conn) -> impl Future> + Send; /// Fetch the current transaction status as mutable /// @@ -57,27 +57,35 @@ pub trait TransactionManager: Send { /// /// Each implementation of this function needs to fulfill the documented /// behaviour of [`AsyncConnection::transaction`] - async fn transaction<'a, F, R, E>(conn: &mut Conn, callback: F) -> Result + fn transaction<'a, 'conn, F, R, E>( + conn: &'conn mut Conn, + callback: F, + ) -> impl Future> + Send + 'conn where F: for<'r> FnOnce(&'r mut Conn) -> ScopedBoxFuture<'a, 'r, Result> + Send + 'a, E: From + Send, R: Send, + 'a: 'conn, { - Self::begin_transaction(conn).await?; - match callback(&mut *conn).await { - Ok(value) => { - Self::commit_transaction(conn).await?; - Ok(value) - } - Err(user_error) => match Self::rollback_transaction(conn).await { - Ok(()) => Err(user_error), - Err(Error::BrokenTransactionManager) => { - // In this case we are probably more interested by the - // original error, which likely caused this - Err(user_error) + async move { + let callback = callback; + + Self::begin_transaction(conn).await?; + match callback(&mut *conn).await { + Ok(value) => { + Self::commit_transaction(conn).await?; + Ok(value) } - Err(rollback_error) => Err(rollback_error.into()), - }, + Err(user_error) => match Self::rollback_transaction(conn).await { + Ok(()) => Err(user_error), + Err(Error::BrokenTransactionManager) => { + // In this case we are probably more interested by the + // original error, which likely caused this + Err(user_error) + } + Err(rollback_error) => Err(rollback_error.into()), + }, + } } } @@ -138,6 +146,11 @@ pub struct AnsiTransactionManager { // See https://github.com/weiznich/diesel_async/issues/198 for // details pub(crate) is_broken: Arc, + // this boolean flag tracks whether we are currently in this process + // of trying to commit the transaction. 
this is useful because if we + // are and we get a serialization failure, we might not want to attempt + // a rollback up the chain. + pub(crate) is_commit: bool, } impl AnsiTransactionManager { @@ -161,15 +174,21 @@ impl AnsiTransactionManager { { let is_broken = conn.transaction_state().is_broken.clone(); let state = Self::get_transaction_state(conn)?; - match state.transaction_depth() { - None => { - Self::critical_transaction_block(&is_broken, conn.batch_execute(sql)).await?; - Self::get_transaction_state(conn)? - .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?; - Ok(()) - } - Some(_depth) => Err(Error::AlreadyInTransaction), + if let Some(_depth) = state.transaction_depth() { + return Err(Error::AlreadyInTransaction); } + let instrumentation_depth = NonZeroU32::new(1); + + conn.instrumentation() + .on_connection_event(InstrumentationEvent::begin_transaction( + instrumentation_depth.expect("We know that 1 is not zero"), + )); + + // Keep remainder of this method in sync with `begin_transaction()`. + Self::critical_transaction_block(&is_broken, conn.batch_execute(sql)).await?; + Self::get_transaction_state(conn)? 
+ .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?; + Ok(()) } // This function should be used to await any connection @@ -193,7 +212,6 @@ impl AnsiTransactionManager { } } -#[async_trait::async_trait] impl TransactionManager for AnsiTransactionManager where Conn: AsyncConnection, @@ -342,9 +360,18 @@ where conn.instrumentation() .on_connection_event(InstrumentationEvent::commit_transaction(depth)); - let is_broken = conn.transaction_state().is_broken.clone(); + let is_broken = { + let transaction_state = conn.transaction_state(); + transaction_state.is_commit = true; + transaction_state.is_broken.clone() + }; + + let res = + Self::critical_transaction_block(&is_broken, conn.batch_execute(&commit_sql)).await; + + conn.transaction_state().is_commit = false; - match Self::critical_transaction_block(&is_broken, conn.batch_execute(&commit_sql)).await { + match res { Ok(()) => { match Self::get_transaction_state(conn)? .change_transaction_depth(TransactionDepthChange::DecreaseDepth) @@ -368,12 +395,8 @@ where .. }) = conn.transaction_state().status { - match Self::critical_transaction_block( - &is_broken, - Self::rollback_transaction(conn), - ) - .await - { + // rollback_transaction handles the critical block internally on its own + match Self::rollback_transaction(conn).await { Ok(()) => {} Err(rollback_error) => { conn.transaction_state().status.set_in_error(); @@ -383,6 +406,9 @@ where }); } } + } else { + Self::get_transaction_state(conn)? 
+ .change_transaction_depth(TransactionDepthChange::DecreaseDepth)?; } Err(commit_error) } diff --git a/tests/instrumentation.rs b/tests/instrumentation.rs index 039ebce..899189d 100644 --- a/tests/instrumentation.rs +++ b/tests/instrumentation.rs @@ -5,6 +5,7 @@ use diesel::connection::InstrumentationEvent; use diesel::query_builder::AsQuery; use diesel::QueryResult; use diesel_async::AsyncConnection; +use diesel_async::AsyncConnectionCore; use diesel_async::SimpleAsyncConnection; use std::num::NonZeroU32; use std::sync::Arc; @@ -54,9 +55,14 @@ impl From> for Event { } async fn setup_test_case() -> (Arc>>, TestConnection) { + setup_test_case_with_connection(connection_with_sean_and_tess_in_users_table().await) +} + +fn setup_test_case_with_connection( + mut conn: TestConnection, +) -> (Arc>>, TestConnection) { let events = Arc::new(Mutex::new(Vec::::new())); let events_to_check = events.clone(); - let mut conn = connection_with_sean_and_tess_in_users_table().await; conn.set_instrumentation(move |event: InstrumentationEvent<'_>| { events.lock().unwrap().push(event.into()); }); @@ -102,7 +108,7 @@ async fn check_events_are_emitted_for_execute_returning_count() { #[tokio::test] async fn check_events_are_emitted_for_load() { let (events_to_check, mut conn) = setup_test_case().await; - let _ = AsyncConnection::load(&mut conn, users::table.as_query()) + let _ = AsyncConnectionCore::load(&mut conn, users::table.as_query()) .await .unwrap(); let events = events_to_check.lock().unwrap(); @@ -128,7 +134,7 @@ async fn check_events_are_emitted_for_execute_returning_count_does_not_contain_c #[tokio::test] async fn check_events_are_emitted_for_load_does_not_contain_cache_for_uncached_queries() { let (events_to_check, mut conn) = setup_test_case().await; - let _ = AsyncConnection::load(&mut conn, diesel::sql_query("select 1")) + let _ = AsyncConnectionCore::load(&mut conn, diesel::sql_query("select 1")) .await .unwrap(); let events = events_to_check.lock().unwrap(); @@ -152,7 
+158,7 @@ async fn check_events_are_emitted_for_execute_returning_count_does_contain_error #[tokio::test] async fn check_events_are_emitted_for_load_does_contain_error_for_failures() { let (events_to_check, mut conn) = setup_test_case().await; - let _ = AsyncConnection::load(&mut conn, diesel::sql_query("invalid")).await; + let _ = AsyncConnectionCore::load(&mut conn, diesel::sql_query("invalid")).await; let events = events_to_check.lock().unwrap(); assert_eq!(events.len(), 2, "{:?}", events); assert_matches!(events[0], Event::StartQuery { .. }); @@ -180,10 +186,10 @@ async fn check_events_are_emitted_for_execute_returning_count_repeat_does_not_re #[tokio::test] async fn check_events_are_emitted_for_load_repeat_does_not_repeat_cache() { let (events_to_check, mut conn) = setup_test_case().await; - let _ = AsyncConnection::load(&mut conn, users::table.as_query()) + let _ = AsyncConnectionCore::load(&mut conn, users::table.as_query()) .await .unwrap(); - let _ = AsyncConnection::load(&mut conn, users::table.as_query()) + let _ = AsyncConnectionCore::load(&mut conn, users::table.as_query()) .await .unwrap(); let events = events_to_check.lock().unwrap(); @@ -255,3 +261,26 @@ async fn check_events_transaction_nested() { assert_matches!(events[10], Event::StartQuery { .. }); assert_matches!(events[11], Event::FinishQuery { .. }); } + +#[cfg(feature = "postgres")] +#[tokio::test] +async fn check_events_transaction_builder() { + use crate::connection_without_transaction; + use diesel::result::Error; + use scoped_futures::ScopedFutureExt; + + let (events_to_check, mut conn) = + setup_test_case_with_connection(connection_without_transaction().await); + conn.build_transaction() + .run(|_tx| async move { Ok::<(), Error>(()) }.scope_boxed()) + .await + .unwrap(); + let events = events_to_check.lock().unwrap(); + assert_eq!(events.len(), 6, "{:?}", events); + assert_matches!(events[0], Event::BeginTransaction { .. }); + assert_matches!(events[1], Event::StartQuery { .. 
}); + assert_matches!(events[2], Event::FinishQuery { .. }); + assert_matches!(events[3], Event::CommitTransaction { .. }); + assert_matches!(events[4], Event::StartQuery { .. }); + assert_matches!(events[5], Event::FinishQuery { .. }); +} diff --git a/tests/lib.rs b/tests/lib.rs index 22701c8..5125e28 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -3,15 +3,16 @@ use diesel::QueryResult; use diesel_async::*; use scoped_futures::ScopedFutureExt; use std::fmt::Debug; -use std::pin::Pin; #[cfg(feature = "postgres")] mod custom_types; mod instrumentation; +mod notifications; #[cfg(any(feature = "bb8", feature = "deadpool", feature = "mobc"))] mod pooling; #[cfg(feature = "async-connection-wrapper")] mod sync_wrapper; +mod transactions; mod type_check; async fn transaction_test>( @@ -19,7 +20,7 @@ async fn transaction_test>( ) -> QueryResult<()> { let res = conn .transaction::(|conn| { - Box::pin(async move { + async move { let users: Vec = users::table.load(conn).await?; assert_eq!(&users[0].name, "John Doe"); assert_eq!(&users[1].name, "Jane Doe"); @@ -55,7 +56,8 @@ async fn transaction_test>( assert_eq!(count, 4); Err(diesel::result::Error::RollbackTransaction) - }) as Pin> + } + .scope_boxed() }) .await; assert_eq!( @@ -99,7 +101,7 @@ type TestConnection = sync_connection_wrapper::SyncConnectionWrapper; #[allow(dead_code)] -type TestBackend = ::Backend; +type TestBackend = ::Backend; #[tokio::test] async fn test_basic_insert_and_load() -> QueryResult<()> { @@ -203,8 +205,7 @@ async fn setup(connection: &mut TestConnection) { } async fn connection() -> TestConnection { - let db_url = std::env::var("DATABASE_URL").unwrap(); - let mut conn = TestConnection::establish(&db_url).await.unwrap(); + let mut conn = connection_without_transaction().await; if cfg!(feature = "postgres") { // postgres allows to modify the schema inside of a transaction conn.begin_test_transaction().await.unwrap(); @@ -218,3 +219,8 @@ async fn connection() -> TestConnection { } conn } + +async fn 
connection_without_transaction() -> TestConnection { + let db_url = std::env::var("DATABASE_URL").unwrap(); + TestConnection::establish(&db_url).await.unwrap() +} diff --git a/tests/notifications.rs b/tests/notifications.rs new file mode 100644 index 0000000..17b790b --- /dev/null +++ b/tests/notifications.rs @@ -0,0 +1,55 @@ +#[cfg(feature = "postgres")] +#[tokio::test] +async fn notifications_arrive() { + use diesel_async::RunQueryDsl; + use futures_util::{StreamExt, TryStreamExt}; + + let conn = &mut super::connection_without_transaction().await; + + diesel::sql_query("LISTEN test_notifications") + .execute(conn) + .await + .unwrap(); + + diesel::sql_query("NOTIFY test_notifications, 'first'") + .execute(conn) + .await + .unwrap(); + + diesel::sql_query("NOTIFY test_notifications, 'second'") + .execute(conn) + .await + .unwrap(); + + let notifications = conn + .notifications_stream() + .take(2) + .try_collect::>() + .await + .unwrap(); + + assert_eq!(2, notifications.len()); + assert_eq!(notifications[0].channel, "test_notifications"); + assert_eq!(notifications[1].channel, "test_notifications"); + assert_eq!(notifications[0].payload, "first"); + assert_eq!(notifications[1].payload, "second"); + + let next_notification = tokio::time::timeout( + std::time::Duration::from_secs(1), + std::pin::pin!(conn.notifications_stream()).next(), + ) + .await; + + assert!( + next_notification.is_err(), + "Got a next notification, while not expecting one: {next_notification:?}" + ); + + diesel::sql_query("NOTIFY test_notifications") + .execute(conn) + .await + .unwrap(); + + let next_notification = std::pin::pin!(conn.notifications_stream()).next().await; + assert_eq!(next_notification.unwrap().unwrap().payload, ""); +} diff --git a/tests/sync_wrapper.rs b/tests/sync_wrapper.rs index 791f89b..576f333 100644 --- a/tests/sync_wrapper.rs +++ b/tests/sync_wrapper.rs @@ -66,3 +66,30 @@ fn check_run_migration() { // just use `run_migrations` here because that's the easiest one 
without additional setup conn.run_migrations(&migrations).unwrap(); } + +#[tokio::test] +async fn test_sync_wrapper_unwrap() { + let db_url = std::env::var("DATABASE_URL").unwrap(); + + let conn = tokio::task::spawn_blocking(move || { + use diesel::RunQueryDsl; + + let mut conn = AsyncConnectionWrapper::::establish(&db_url).unwrap(); + let res = + diesel::select(1.into_sql::()).get_result::(&mut conn); + assert_eq!(Ok(1), res); + conn + }) + .await + .unwrap(); + + { + use diesel_async::RunQueryDsl; + + let mut conn = conn.into_inner(); + let res = diesel::select(1.into_sql::()) + .get_result::(&mut conn) + .await; + assert_eq!(Ok(1), res); + } +} diff --git a/tests/transactions.rs b/tests/transactions.rs new file mode 100644 index 0000000..6de782c --- /dev/null +++ b/tests/transactions.rs @@ -0,0 +1,228 @@ +#[cfg(feature = "postgres")] +#[tokio::test] +async fn concurrent_serializable_transactions_behave_correctly() { + use diesel::prelude::*; + use diesel_async::RunQueryDsl; + use std::sync::Arc; + use tokio::sync::Barrier; + + table! 
{ + users3 { + id -> Integer, + } + } + + // create an async connection + let mut conn = super::connection_without_transaction().await; + + let mut conn1 = super::connection_without_transaction().await; + + diesel::sql_query("CREATE TABLE IF NOT EXISTS users3 (id int);") + .execute(&mut conn) + .await + .unwrap(); + + let barrier_1 = Arc::new(Barrier::new(2)); + let barrier_2 = Arc::new(Barrier::new(2)); + let barrier_1_for_tx1 = barrier_1.clone(); + let barrier_1_for_tx2 = barrier_1.clone(); + let barrier_2_for_tx1 = barrier_2.clone(); + let barrier_2_for_tx2 = barrier_2.clone(); + + let mut tx = conn.build_transaction().serializable().read_write(); + + let res = tx.run(|conn| { + Box::pin(async { + users3::table.select(users3::id).load::(conn).await?; + + barrier_1_for_tx1.wait().await; + diesel::insert_into(users3::table) + .values(users3::id.eq(1)) + .execute(conn) + .await?; + barrier_2_for_tx1.wait().await; + + Ok::<_, diesel::result::Error>(()) + }) + }); + + let mut tx1 = conn1.build_transaction().serializable().read_write(); + + let res1 = async { + let res = tx1 + .run(|conn| { + Box::pin(async { + users3::table.select(users3::id).load::(conn).await?; + + barrier_1_for_tx2.wait().await; + diesel::insert_into(users3::table) + .values(users3::id.eq(1)) + .execute(conn) + .await?; + + Ok::<_, diesel::result::Error>(()) + }) + }) + .await; + barrier_2_for_tx2.wait().await; + res + }; + + let (res, res1) = tokio::join!(res, res1); + let _ = diesel::sql_query("DROP TABLE users3") + .execute(&mut conn1) + .await; + + assert!( + res1.is_ok(), + "Expected the second transaction to be succussfull, but got an error: {:?}", + res1.unwrap_err() + ); + + assert!(res.is_err(), "Expected the first transaction to fail"); + let err = res.unwrap_err(); + assert!( + matches!( + &err, + diesel::result::Error::DatabaseError( + diesel::result::DatabaseErrorKind::SerializationFailure, + _ + ) + ), + "Expected an serialization failure but got another error: {err:?}" + ); + + let 
mut tx = conn.build_transaction(); + + let res = tx + .run(|_| Box::pin(async { Ok::<_, diesel::result::Error>(()) })) + .await; + + assert!( + res.is_ok(), + "Expect transaction to run fine but got an error: {:?}", + res.unwrap_err() + ); +} + +#[cfg(feature = "postgres")] +#[tokio::test] +async fn commit_with_serialization_failure_already_ends_transaction() { + use diesel::prelude::*; + use diesel_async::{AsyncConnection, RunQueryDsl}; + use std::sync::Arc; + use tokio::sync::Barrier; + + table! { + users4 { + id -> Integer, + } + } + + // create an async connection + let mut conn = super::connection_without_transaction().await; + + struct A(Vec<&'static str>); + impl diesel::connection::Instrumentation for A { + fn on_connection_event(&mut self, event: diesel::connection::InstrumentationEvent<'_>) { + if let diesel::connection::InstrumentationEvent::StartQuery { query, .. } = event { + let q = query.to_string(); + let q = q.split_once(' ').map(|(a, _)| a).unwrap_or(&q); + + if matches!(q, "BEGIN" | "COMMIT" | "ROLLBACK") { + assert_eq!(q, self.0.pop().unwrap()); + } + } + } + } + conn.set_instrumentation(A(vec!["COMMIT", "BEGIN", "COMMIT", "BEGIN"])); + + let mut conn1 = super::connection_without_transaction().await; + + diesel::sql_query("CREATE TABLE IF NOT EXISTS users4 (id int);") + .execute(&mut conn) + .await + .unwrap(); + + let barrier_1 = Arc::new(Barrier::new(2)); + let barrier_2 = Arc::new(Barrier::new(2)); + let barrier_1_for_tx1 = barrier_1.clone(); + let barrier_1_for_tx2 = barrier_1.clone(); + let barrier_2_for_tx1 = barrier_2.clone(); + let barrier_2_for_tx2 = barrier_2.clone(); + + let mut tx = conn.build_transaction().serializable().read_write(); + + let res = tx.run(|conn| { + Box::pin(async { + users4::table.select(users4::id).load::(conn).await?; + + barrier_1_for_tx1.wait().await; + diesel::insert_into(users4::table) + .values(users4::id.eq(1)) + .execute(conn) + .await?; + barrier_2_for_tx1.wait().await; + + Ok::<_, 
diesel::result::Error>(()) + }) + }); + + let mut tx1 = conn1.build_transaction().serializable().read_write(); + + let res1 = async { + let res = tx1 + .run(|conn| { + Box::pin(async { + users4::table.select(users4::id).load::(conn).await?; + + barrier_1_for_tx2.wait().await; + diesel::insert_into(users4::table) + .values(users4::id.eq(1)) + .execute(conn) + .await?; + + Ok::<_, diesel::result::Error>(()) + }) + }) + .await; + barrier_2_for_tx2.wait().await; + res + }; + + let (res, res1) = tokio::join!(res, res1); + let _ = diesel::sql_query("DROP TABLE users4") + .execute(&mut conn1) + .await; + + assert!( + res1.is_ok(), + "Expected the second transaction to be succussfull, but got an error: {:?}", + res1.unwrap_err() + ); + + assert!(res.is_err(), "Expected the first transaction to fail"); + let err = res.unwrap_err(); + assert!( + matches!( + &err, + diesel::result::Error::DatabaseError( + diesel::result::DatabaseErrorKind::SerializationFailure, + _ + ) + ), + "Expected an serialization failure but got another error: {err:?}" + ); + + let mut tx = conn.build_transaction(); + + let res = tx + .run(|_| Box::pin(async { Ok::<_, diesel::result::Error>(()) })) + .await; + + assert!( + res.is_ok(), + "Expect transaction to run fine but got an error: {:?}", + res.unwrap_err() + ); +} diff --git a/tests/type_check.rs b/tests/type_check.rs index 52ff8c3..a074796 100644 --- a/tests/type_check.rs +++ b/tests/type_check.rs @@ -4,14 +4,14 @@ use diesel::expression::{AsExpression, ValidGrouping}; use diesel::prelude::*; use diesel::query_builder::{NoFromClause, QueryFragment, QueryId}; use diesel::sql_types::{self, HasSqlType, SingleValue}; -use diesel_async::{AsyncConnection, RunQueryDsl}; +use diesel_async::{AsyncConnectionCore, RunQueryDsl}; use std::fmt::Debug; async fn type_check(conn: &mut TestConnection, value: T) where T: Clone + AsExpression - + FromSqlRow::Backend> + + FromSqlRow::Backend> + Send + PartialEq + Debug @@ -19,10 +19,10 @@ where + 'static, 
T::Expression: ValidGrouping<()> + SelectableExpression - + QueryFragment<::Backend> + + QueryFragment<::Backend> + QueryId + Send, - ::Backend: HasSqlType, + ::Backend: HasSqlType, ST: SingleValue, { let res = diesel::select(value.clone().into_sql())