diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 71c844e0c..f1edc7ec3 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -16,7 +16,6 @@ concurrency: jobs: build: - runs-on: ubuntu-latest timeout-minutes: 60 strategy: @@ -48,7 +47,13 @@ jobs: db-org: valkey-io, db-name: valkey, db-version: 8.0.1 - }, + }, + { + rust: stable, + db-org: redis, + db-name: redis, + db-version: 8.0-rc1 + }, # Different rust cases { @@ -188,7 +193,7 @@ jobs: - run: cargo clippy --all-features --all-targets -- -D warnings name: clippy - name: doc - run: cargo doc --no-deps --document-private-items + run: cargo doc --all-features --no-deps --document-private-items env: RUSTDOCFLAGS: -Dwarnings @@ -250,5 +255,21 @@ jobs: - uses: Swatinem/rust-cache@v2 - run: | - cargo install --git https://github.com/TheBevyFlock/flag-frenzy.git + cargo install --git https://github.com/nihohit/flag-frenzy.git flag-frenzy --package redis + + windows-build: + runs-on: windows-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain/@master + with: + toolchain: stable + + - uses: Swatinem/rust-cache@v2 + + - name: Build + run: make build diff --git a/Cargo.lock b/Cargo.lock index 89dc7257a..5e9370e9e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,9 +19,9 @@ checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" [[package]] name = "afl" -version = "0.15.14" +version = "0.15.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe38b594cedcc69d8d58022f17a98438e9a2b18c9028a0c3db4bc706f32ab8e7" +checksum = "92ab76b4a49c1d3dcd5032a3c365670838db36b3154716046d57a4a3ce4298ec" dependencies = [ "home", "libc", @@ -82,9 +82,9 @@ checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" [[package]] name = "anyhow" -version = "1.0.95" +version = "1.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04" +checksum = "6b964d184e89d9b6b67dd2715bc8e74cf3107fb2b529990c90cf517326150bf4" [[package]] name = "arc-swap" @@ -124,7 +124,7 @@ dependencies = [ "concurrent-queue", "event-listener-strategy", "futures-core", - "pin-project-lite", + "pin-project-lite 0.2.16", ] [[package]] @@ -140,6 +140,17 @@ dependencies = [ "slab", ] +[[package]] +name = "async-fs" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebcd09b382f40fcd159c2d695175b2ae620ffa5f3bd6f664131efff4e8b9e04a" +dependencies = [ + "async-lock 3.4.0", + "blocking", + "futures-lite 2.6.0", +] + [[package]] name = "async-global-executor" version = "2.4.1" @@ -211,7 +222,7 @@ checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18" dependencies = [ "event-listener 5.4.0", "event-listener-strategy", - "pin-project-lite", + "pin-project-lite 0.2.16", ] [[package]] @@ -226,6 +237,54 @@ dependencies = [ "url", ] +[[package]] +name = "async-net" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b948000fad4873c1c9339d60f2623323a0cfd3816e5181033c6a5cb68b2accf7" +dependencies = [ + "async-io 2.4.0", + "blocking", + "futures-lite 2.6.0", +] + +[[package]] +name = "async-process" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63255f1dc2381611000436537bbedfe83183faa303a5a0edaf191edef06526bb" +dependencies = [ + "async-channel 2.3.1", + "async-io 2.4.0", + "async-lock 
3.4.0", + "async-signal", + "async-task", + "blocking", + "cfg-if", + "event-listener 5.4.0", + "futures-lite 2.6.0", + "rustix 0.38.44", + "tracing", +] + +[[package]] +name = "async-signal" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfb3634b73397aa844481f814fad23bbf07fdb0eabec10f2eb95e58944b1ec32" +dependencies = [ + "async-io 2.4.0", + "async-lock 3.4.0", + "atomic-waker", + "cfg-if", + "futures-core", + "futures-io", + "rustix 0.38.44", + "signal-hook-registry", + "slab", + "windows-sys 0.52.0", +] + [[package]] name = "async-std" version = "1.13.0" @@ -246,7 +305,7 @@ dependencies = [ "log", "memchr", "once_cell", - "pin-project-lite", + "pin-project-lite 0.2.16", "pin-utils", "slab", "wasm-bindgen-futures", @@ -272,9 +331,9 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "backon" -version = "1.3.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba5289ec98f68f28dd809fd601059e6aa908bb8f6108620930828283d4ee23d7" +checksum = "970d91570c01a8a5959b36ad7dd1c30642df24b6b3068710066f6809f7033bb7" dependencies = [ "fastrand 2.3.0", ] @@ -294,6 +353,17 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "bb8" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "212d8b8e1a22743d9241575c6ba822cf9c8fef34771c86ab7e477a4fbfd254e5" +dependencies = [ + "futures-util", + "parking_lot", + "tokio", +] + [[package]] name = "bigdecimal" version = "0.4.7" @@ -495,7 +565,7 @@ dependencies = [ "bytes", "futures-core", "memchr", - "pin-project-lite", + "pin-project-lite 0.2.16", "tokio", "tokio-util", ] @@ -519,6 +589,16 @@ dependencies = [ "libc", ] +[[package]] +name = "core-foundation" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -655,7 +735,7 @@ checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae" dependencies = [ "concurrent-queue", "parking", - "pin-project-lite", + "pin-project-lite 0.2.16", ] [[package]] @@ -665,7 +745,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c3e4e0dd3673c1139bf041f3008816d9cf2946bbfac2945c09e523b8d7b05b2" dependencies = [ "event-listener 5.4.0", - "pin-project-lite", + "pin-project-lite 0.2.16", ] [[package]] @@ -784,7 +864,7 @@ dependencies = [ "futures-io", "memchr", "parking", - "pin-project-lite", + "pin-project-lite 0.2.16", "waker-fn", ] @@ -798,7 +878,7 @@ dependencies = [ "futures-core", "futures-io", "parking", - "pin-project-lite", + "pin-project-lite 0.2.16", ] [[package]] @@ -844,7 +924,7 @@ dependencies = [ "async-channel 1.9.0", "async-io 1.13.0", "futures-core", - "pin-project-lite", + "pin-project-lite 0.2.16", ] [[package]] @@ -866,7 +946,7 @@ dependencies = [ "futures-sink", "futures-task", "memchr", - "pin-project-lite", + "pin-project-lite 0.2.16", "pin-utils", "slab", ] @@ -1193,9 +1273,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.169" +version = "0.2.171" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6" [[package]] name = "libm" @@ -1215,6 +1295,12 @@ version 
= "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" +[[package]] +name = "linux-raw-sys" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe7db12097d22ec582439daf8618b8fdd1a7bef6270e9af3b1ebcd30893cf413" + [[package]] name = "litemap" version = "0.7.4" @@ -1233,9 +1319,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.25" +version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" +checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" dependencies = [ "value-bag", ] @@ -1287,7 +1373,7 @@ dependencies = [ "openssl-probe", "openssl-sys", "schannel", - "security-framework", + "security-framework 2.11.1", "security-framework-sys", "tempfile", ] @@ -1343,9 +1429,9 @@ checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" [[package]] name = "openssl" -version = "0.10.70" +version = "0.10.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61cfb4e166a8bb8c9b55c500bc2308550148ece889be90f609377e58140f42c6" +checksum = "fedfea7d58a1f73118430a55da6a286e7b044961736ce96a16a17068ea25e5da" dependencies = [ "bitflags 2.8.0", "cfg-if", @@ -1375,9 +1461,9 @@ checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" [[package]] name = "openssl-sys" -version = "0.9.105" +version = "0.9.107" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b22d5b84be05a8d6947c7cb71f7c849aa0f112acd4bf51c2a7c1c988ac0a9dc" +checksum = "8288979acd84749c744a9014b4382d42b8f7b2592847b5afb2ed29e5d16ede07" dependencies = [ "cc", "libc", @@ -1431,6 +1517,12 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "pin-project-lite" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "257b64915a082f7811703966789728173279bdebb956b143dbcd23f6f970a777" + [[package]] name = "pin-project-lite" version = "0.2.16" @@ -1500,7 +1592,7 @@ dependencies = [ "concurrent-queue", "libc", "log", - "pin-project-lite", + "pin-project-lite 0.2.16", "windows-sys 0.48.0", ] @@ -1513,7 +1605,7 @@ dependencies = [ "cfg-if", "concurrent-queue", "hermit-abi 0.4.0", - "pin-project-lite", + "pin-project-lite 0.2.16", "rustix 0.38.44", "tracing", "windows-sys 0.59.0", @@ -1686,17 +1778,20 @@ dependencies = [ [[package]] name = "redis" -version = "0.29.0" +version = "0.30.0" dependencies = [ "ahash 0.8.11", "anyhow", "arc-swap", "assert_approx_eq", + "async-io 2.4.0", "async-native-tls", "async-std", "backon", + "bb8", "bigdecimal", "bytes", + "cfg-if", "combine", "crc16", "criterion", @@ -1716,7 +1811,7 @@ dependencies = [ "once_cell", "partial-io", "percent-encoding", - "pin-project-lite", + "pin-project-lite 0.2.16", "quickcheck", "r2d2", "rand 0.9.0", @@ -1729,7 +1824,9 @@ dependencies = [ "serde", "serde_json", "sha1_smol", - "socket2 0.5.8", + "smol", + "smol-timeout", + "socket2 0.5.9", "tempfile", "tokio", "tokio-native-tls", @@ -1742,13 +1839,13 @@ dependencies = [ [[package]] name = "redis-test" -version = "0.9.0" +version = "0.10.0" dependencies = [ "bytes", "futures", "rand 0.9.0", "redis", - "socket2 0.5.8", + "socket2 0.5.9", "tempfile", "tokio", ] @@ -1808,15 +1905,14 @@ dependencies = [ 
[[package]] name = "ring" -version = "0.17.8" +version = "0.17.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" +checksum = "70ac5d832aa16abd7d1def883a8545280c20a60f523a370aa3a9617c2b8550ee" dependencies = [ "cc", "cfg-if", "getrandom 0.2.15", "libc", - "spin", "untrusted", "windows-sys 0.52.0", ] @@ -1938,11 +2034,24 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "rustix" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a04a32fb43fcdef85b977ce7f77a150805e4b2ea1f2656898d4a547dde78df6" +dependencies = [ + "bitflags 2.8.0", + "errno", + "libc", + "linux-raw-sys 0.9.3", + "windows-sys 0.59.0", +] + [[package]] name = "rustls" -version = "0.23.23" +version = "0.23.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395" +checksum = "822ee9188ac4ec04a2f0531e55d035fb2de73f18b41a63c70c2712503b6fb13c" dependencies = [ "once_cell", "ring", @@ -1954,24 +2063,14 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.7.3" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5" +checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" dependencies = [ "openssl-probe", - "rustls-pemfile", "rustls-pki-types", "schannel", - "security-framework", -] - -[[package]] -name = "rustls-pemfile" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" -dependencies = [ - "rustls-pki-types", + "security-framework 3.1.0", ] [[package]] @@ -1982,9 +2081,9 @@ checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" [[package]] name = "rustls-webpki" -version = "0.102.8" +version = "0.103.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" +checksum = "fef8b8769aaccf73098557a87cd1816b4f9c7c16811c9c77142aa695c16f2c03" dependencies = [ "ring", "rustls-pki-types", @@ -2049,7 +2148,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags 2.8.0", - "core-foundation", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81d3f8c9bfcc3cbb6b0179eb57042d75b1582bdc65c3cb95f3fa999509c03cbc" +dependencies = [ + "bitflags 2.8.0", + "core-foundation 0.10.0", "core-foundation-sys", "libc", "security-framework-sys", @@ -2057,9 +2169,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.14.0" +version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" +checksum = "1863fd3768cd83c56a7f60faa4dc0d403f1b6df0a38c3c25f44b7894e45370d5" dependencies = [ "core-foundation-sys", "libc", @@ -2073,18 +2185,18 @@ checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03" [[package]] name = "serde" -version = "1.0.217" +version = "1.0.218" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" +checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.217" +version = "1.0.218" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" +checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b" dependencies = [ "proc-macro2", "quote", @@ -2093,9 +2205,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.138" +version = "1.0.139" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949" +checksum = "44f86c3acccc9c65b153fe1b85a3be07fe5515274ec9f0653b4a0875731c72a6" dependencies = [ "itoa", "memchr", @@ -2115,6 +2227,15 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "signal-hook-registry" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +dependencies = [ + "libc", +] + [[package]] name = "simdutf8" version = "0.1.5" @@ -2136,6 +2257,33 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "smol" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a33bd3e260892199c3ccfc487c88b2da2265080acb316cd920da72fdfd7c599f" +dependencies = [ + "async-channel 2.3.1", + "async-executor", + "async-fs", + "async-io 2.4.0", + "async-lock 3.4.0", + "async-net", + "async-process", + "blocking", + "futures-lite 2.6.0", +] + +[[package]] +name = "smol-timeout" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "847d777e2c6c166bad26264479e80a9820f3d364fcb4a0e23cd57bbfa8e94961" +dependencies = [ + "async-io 1.13.0", + "pin-project-lite 0.1.12", +] + [[package]] name = "socket2" version = "0.4.10" @@ -2148,20 +2296,14 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.8" +version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" +checksum = "4f5fd57c80058a56cf5c777ab8a126398ece8e442983605d280a44ce79d0edef" dependencies = [ "libc", "windows-sys 0.52.0", ] -[[package]] -name = "spin" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" - [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -2215,15 +2357,14 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "tempfile" -version = "3.16.0" +version = "3.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38c246215d7d24f48ae091a2902398798e05d978b24315d6efbc00ede9a8bb91" +checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf" dependencies = [ - "cfg-if", "fastrand 2.3.0", "getrandom 0.3.1", "once_cell", - "rustix 0.38.44", + "rustix 1.0.4", "windows-sys 0.59.0", ] @@ -2284,16 +2425,17 @@ checksum = 
"1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.43.0" +version = "1.44.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e" +checksum = "e6b88822cbe49de4185e3a4cbf8321dd487cf5fe0c5c65695fef6346371e9c48" dependencies = [ "backtrace", "bytes", "libc", "mio", - "pin-project-lite", - "socket2 0.5.8", + "parking_lot", + "pin-project-lite 0.2.16", + "socket2 0.5.9", "tokio-macros", "windows-sys 0.52.0", ] @@ -2321,9 +2463,9 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.26.1" +version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37" +checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b" dependencies = [ "rustls", "tokio", @@ -2331,14 +2473,14 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.13" +version = "0.7.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" +checksum = "6b9590b93e6fcc1739458317cccd391ad3955e2bde8913edf6f95f9e65a8f034" dependencies = [ "bytes", "futures-core", "futures-sink", - "pin-project-lite", + "pin-project-lite 0.2.16", "tokio", ] @@ -2365,7 +2507,7 @@ version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ - "pin-project-lite", + "pin-project-lite 0.2.16", "tracing-core", ] @@ -2412,9 +2554,9 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "uuid" -version = "1.13.1" +version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ced87ca4be083373936a67f8de945faa23b6b42384bd5b64434850802c6dccd0" +checksum = "e0f540e3240398cce6128b64ba83fdbdd86129c16a3aa1a3a252efd66eb3d587" [[package]] name = "valkey" diff --git a/Makefile b/Makefile index 093f0b7ef..f3216e83b 100644 --- a/Makefile +++ b/Makefile @@ -30,7 +30,7 @@ test: @echo "====================================================================" @echo "Testing Connection Type TCP with native-TLS support" @echo "====================================================================" - @RUSTFLAGS="-D warnings" REDISRS_SERVER_TYPE=tcp+tls RUST_BACKTRACE=1 cargo nextest run --locked -p redis --features=json,tokio-native-tls-comp,async-std-native-tls-comp,connection-manager,cluster-async -E 'not test(test_module)' + @RUSTFLAGS="-D warnings" REDISRS_SERVER_TYPE=tcp+tls RUST_BACKTRACE=1 cargo nextest run --locked -p redis --features=json,tokio-native-tls-comp,async-std-native-tls-comp,smol-native-tls-comp,connection-manager,cluster-async -E 'not test(test_module)' @echo "====================================================================" @echo "Testing Connection Type UNIX" @@ -70,7 +70,7 @@ bench: cargo bench --all-features docs: - @RUSTFLAGS="-D warnings" RUSTDOCFLAGS="--cfg docsrs" cargo +nightly doc --all-features --no-deps + @RUSTDOCFLAGS="-D warnings --cfg docsrs" cargo +nightly doc --all-features --no-deps upload-docs: docs @./upload-docs.sh diff --git a/README.md b/README.md index 9065d45cd..874c652e0 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ The crate is called `redis` and you can depend on it via cargo: ```ini [dependencies] -redis = "0.29.0" +redis = "0.30.0" ``` Documentation on the 
library can be found at
@@ -53,14 +53,17 @@ you can implement the `FromRedisValue` and `ToRedisArgs` traits, or derive it wi
 ## Async support
 
 To enable asynchronous clients, enable the relevant feature in your Cargo.toml,
-`tokio-comp` for tokio users or `async-std-comp` for async-std users.
+`tokio-comp` for tokio users, `smol-comp` for smol users, or `async-std-comp` for async-std users.
 
 ```
 # if you use tokio
-redis = { version = "0.29.0", features = ["tokio-comp"] }
+redis = { version = "0.30.0", features = ["tokio-comp"] }
+
+# if you use smol
+redis = { version = "0.30.0", features = ["smol-comp"] }
 
 # if you use async-std
-redis = { version = "0.29.0", features = ["async-std-comp"] }
+redis = { version = "0.30.0", features = ["async-std-comp"] }
 ```
 
 ## Connection Pooling
 
@@ -69,7 +72,7 @@ When using a sync connection, it is recommended to use a connection pool in orde
 disconnects or multi-threaded usage. This can be done using the `r2d2` feature.
 
 ```
-redis = { version = "0.29.0", features = ["r2d2"] }
+redis = { version = "0.30.0", features = ["r2d2"] }
 ```
 
 For async connections, connection pooling isn't necessary, unless blocking commands are used.
@@ -91,25 +94,31 @@ Currently, `native-tls` and `rustls` are supported.
 To use `native-tls`:
 
 ```
-redis = { version = "0.29.0", features = ["tls-native-tls"] }
+redis = { version = "0.30.0", features = ["tls-native-tls"] }
 
 # if you use tokio
-redis = { version = "0.29.0", features = ["tokio-native-tls-comp"] }
+redis = { version = "0.30.0", features = ["tokio-native-tls-comp"] }
+
+# if you use smol
+redis = { version = "0.30.0", features = ["smol-native-tls-comp"] }
 
 # if you use async-std
-redis = { version = "0.29.0", features = ["async-std-native-tls-comp"] }
+redis = { version = "0.30.0", features = ["async-std-native-tls-comp"] }
 ```
 
 To use `rustls`:
 
 ```
-redis = { version = "0.29.0", features = ["tls-rustls"] }
+redis = { version = "0.30.0", features = ["tls-rustls"] }
 
 # if you use tokio
-redis = { version = "0.29.0", features = ["tokio-rustls-comp"] }
+redis = { version = "0.30.0", features = ["tokio-rustls-comp"] }
+
+# if you use smol
+redis = { version = "0.30.0", features = ["smol-rustls-comp"] }
 
 # if you use async-std
-redis = { version = "0.29.0", features = ["async-std-rustls-comp"] }
+redis = { version = "0.30.0", features = ["async-std-rustls-comp"] }
 ```
 
 Add `rustls` to dependencies
@@ -150,7 +159,7 @@ let client = redis::Client::open("rediss://127.0.0.1/#insecure")?;
 
 Support for Redis Cluster can be enabled by enabling the `cluster` feature in your Cargo.toml:
 
-`redis = { version = "0.29.0", features = [ "cluster"] }`
+`redis = { version = "0.30.0", features = [ "cluster"] }`
 
 Then you can simply use the `ClusterClient`, which accepts a list of available nodes. Note that
 only one node in the cluster needs to be specified when instantiating the client, though
@@ -173,7 +182,7 @@ fn fetch_an_integer() -> String {
 Async Redis Cluster support can be enabled by enabling the `cluster-async` feature, along
 with your preferred async runtime, e.g.:
 
-`redis = { version = "0.29.0", features = [ "cluster-async", "tokio-std-comp" ] }`
+`redis = { version = "0.30.0", features = [ "cluster-async", "tokio-comp" ] }`
 
 ```rust
 use redis::cluster::ClusterClient;
@@ -193,7 +202,7 @@ async fn fetch_an_integer() -> String {
 
 Support for the RedisJSON Module can be enabled by specifying "json" as a feature in your
 Cargo.toml.
 
-`redis = { version = "0.29.0", features = ["json"] }`
+`redis = { version = "0.30.0", features = ["json"] }`
 
 Then you can simply import the `JsonCommands` trait which will add the `json` commands to all
 Redis Connections (not to be confused with just `Commands` which only adds the default commands)
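[Editor's note, not part of the upstream diff: since the README section above now lists `smol-comp` alongside the tokio and async-std features, a minimal usage sketch may help. It assumes a server on `127.0.0.1:6379` and that `smol-comp` is the only async-runtime feature enabled, so the client drives its I/O on the smol runtime; the command API is the same across all three backends.]

```rust
use redis::AsyncCommands;

fn main() -> redis::RedisResult<()> {
    smol::block_on(async {
        let client = redis::Client::open("redis://127.0.0.1/")?;
        let mut con = client.get_multiplexed_async_connection().await?;
        // SET then GET, exactly as in the tokio-based README examples.
        let _: () = con.set("answer", 42).await?;
        let answer: i32 = con.get("answer").await?;
        assert_eq!(answer, 42);
        Ok(())
    })
}
```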
diff --git a/config/redis.toml b/config/redis.toml
index f148be902..e3f0e1756 100644
--- a/config/redis.toml
+++ b/config/redis.toml
@@ -1,10 +1,10 @@
 [[rule]]
 when = "cluster-async"
-require = ["tokio-comp", "OR", "async-std-comp"]
+require = ["tokio-comp", "OR", "smol-comp"]
 
 [[rule]]
 when = "connection-manager"
-require = ["tokio-comp", "OR", "async-std-comp"]
+require = ["tokio-comp", "OR", "smol-comp"]
 
 [[rule]]
 when = ["tls-rustls", "tokio-comp"]
@@ -15,12 +15,12 @@ when = ["tls-native-tls", "tokio-comp"]
 require = ["tokio-native-tls-comp"]
 
 [[rule]]
-when = ["tls-rustls", "async-std-comp"]
-require = ["async-std-rustls-comp"]
+when = ["tls-rustls", "smol-comp"]
+require = ["smol-rustls-comp"]
 
 [[rule]]
-when = ["tls-native-tls", "async-std-comp"]
-require = ["async-std-native-tls-comp"]
+when = ["tls-native-tls", "smol-comp"]
+require = ["smol-native-tls-comp"]
 
 [[rule]]
 when = "tls-rustls-insecure"
@@ -28,26 +28,59 @@ require = ["tls-rustls"]
 
 # the native-tls features shouldn't be tested with the rustls features
 [[rule]]
-when = ["tokio-native-tls-comp", "OR", "async-std-native-tls-comp", "OR", "tls-native-tls"]
-forbid = ["tokio-rustls-comp", "OR", "tls-rustls-webpki-roots", "OR", "tls-rustls", "OR", "tls-rustls-insecure", "OR", "async-std-rustls-comp"]
-
-# we don't need to check whether the async-std features are working with the tokio features
-[[rule]]
-when = ["async-std-native-tls-comp", "OR", "async-std-comp", "OR", "async-std-rustls-comp"]
-forbid = ["tokio-rustls-comp", "OR", "tokio-native-tls-comp", "OR", "tokio-comp"]
-
-[[rule]]
-when = ["async-std-native-tls-comp"]
-forbid = ["async-std-rustls-comp"]
-
-[[rule]]
-when = ["tls-rustls-webpki-roots", "async-std-comp"]
-require = ["async-std-rustls-comp"]
+when = [
+    "tokio-native-tls-comp",
+    "OR",
+    "smol-native-tls-comp",
+    "OR",
+    "tls-native-tls",
+]
+forbid = [
+    "tokio-rustls-comp",
+    "OR",
+    "tls-rustls-webpki-roots",
+    "OR",
+    "tls-rustls",
+    "OR",
+    "tls-rustls-insecure",
+    "OR",
+    "smol-rustls-comp",
+]
+
+# we don't need to check whether the smol features are working with the tokio features
+[[rule]]
+when = [
+    "smol-native-tls-comp",
+    "OR",
+    "smol-comp",
+    "OR",
+    "smol-rustls-comp",
+]
+forbid = [
+    "tokio-rustls-comp",
+    "OR",
+    "tokio-native-tls-comp",
+    "OR",
+    "tokio-comp",
+]
+
+[[rule]]
+when = ["smol-native-tls-comp"]
+forbid = ["smol-rustls-comp"]
+
+[[rule]]
+when = ["tls-rustls-webpki-roots", "smol-comp"]
+require = ["smol-rustls-comp"]
 
 [[rule]]
 when = ["tls-rustls-webpki-roots", "tokio-comp"]
 require = ["tokio-rustls-comp"]
 
+# Users are expected to use only one connection pool.
+[[rule]]
+when = "r2d2"
+forbid = "bb8"
+
 # This feature can't run by itself
 [[rule]]
 when = true
@@ -61,13 +94,31 @@ forbid = "tls" # deprecated
 
 [[rule]]
 when = true
-forbid = "async-std-tls-comp"
+forbid = [
+    "async-std-tls-comp",
+    "OR",
+    "async-std-comp",
+    "OR",
+    "async-std-native-tls-comp",
+    "OR",
+    "async-std-rustls-comp",
+]
 
 # these are all included in the `default` feature, so in order to reduce combinatoric explosion, we don't check them individually.
 [[rule]]
 when = true
-forbid = ["acl", "OR", "streams", "OR", "geospatial", "OR", "script", "OR", "keep-alive"]
+forbid = [
+    "acl",
+    "OR",
+    "streams",
+    "OR",
+    "geospatial",
+    "OR",
+    "script",
+    "OR",
+    "keep-alive",
+]
 
 [[rule]]
 when = "cache-aio"
-require = ["tokio-comp", "OR", "async-std-comp"]
+require = ["tokio-comp", "OR", "smol-comp"]
diff --git a/redis-test/CHANGELOG.md b/redis-test/CHANGELOG.md
index b0f2fb272..e2f14adeb 100644
--- a/redis-test/CHANGELOG.md
+++ b/redis-test/CHANGELOG.md
@@ -1,3 +1,6 @@
+### 0.10.0 (2025-04-22)
+* Track redis 0.30.0 release
+
 ### 0.9.0 (2025-02-16)
 * Track redis 0.29.0 release
 
diff --git a/redis-test/Cargo.toml b/redis-test/Cargo.toml
index a6552525c..b658e0ae6 100644
--- a/redis-test/Cargo.toml
+++ b/redis-test/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "redis-test"
-version = "0.9.0"
+version = "0.10.0"
 edition = "2021"
 description = "Testing helpers for the `redis` crate"
 homepage = "https://github.com/redis-rs/redis-rs"
@@ -13,10 +13,10 @@ rust-version = "1.75"
 bench = false
 
 [dependencies]
-redis = { version = "0.29.0", path = "../redis" }
+redis = { version = "0.30", path = "../redis" }
 bytes = { version = "1", optional = true }
 futures = { version = "0.3", optional = true }
-tempfile = "=3.16.0"
+tempfile = "=3.19.1"
 socket2 = "0.5"
 rand = "0.9"
 
@@ -24,7 +24,7 @@ rand = "0.9"
 aio = ["futures", "redis/aio"]
 
 [dev-dependencies]
-redis = { version = "0.29.0", path = "../redis", features = [
+redis = { version = "0.30", path = "../redis", features = [
     "aio",
     "tokio-comp",
 ] }
diff --git a/redis-test/src/cluster.rs b/redis-test/src/cluster.rs
index f1a51991e..b13eaa1f0 100644
--- a/redis-test/src/cluster.rs
+++ b/redis-test/src/cluster.rs
@@ -4,15 +4,17 @@ use tempfile::TempDir;
 
 use crate::{
     server::{Module, RedisServer},
-    utils::{build_keys_and_certs_for_tls, get_random_available_port, TlsFilePaths},
+    utils::{build_keys_and_certs_for_tls_ext, get_random_available_port, TlsFilePaths},
 };
 
 pub struct RedisClusterConfiguration {
     pub num_nodes: u16,
     pub num_replicas: u16,
     pub modules: Vec<Module>,
+    pub tls_insecure: bool,
     pub mtls_enabled: bool,
     pub ports: Vec<u16>,
+    pub certs_with_ip_alts: bool,
 }
 
 impl RedisClusterConfiguration {
@@ -31,19 +33,22 @@ impl Default for RedisClusterConfiguration {
             num_nodes: 3,
             num_replicas: 0,
             modules: vec![],
+            tls_insecure: true,
             mtls_enabled: false,
             ports: vec![],
+            certs_with_ip_alts: true,
         }
     }
 }
 
-enum ClusterType {
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum ClusterType {
     Tcp,
     TcpTls,
 }
 
 impl ClusterType {
-    fn get_intended() -> ClusterType {
+    pub fn get_intended() -> ClusterType {
         match env::var("REDISRS_SERVER_TYPE")
             .ok()
             .as_ref()
@@ -103,8 +108,10 @@ impl RedisCluster {
             num_nodes: nodes,
             num_replicas: replicas,
             modules,
+            tls_insecure,
             mtls_enabled,
             ports,
+            certs_with_ip_alts,
         } = configuration;
 
         let optional_ports = if ports.is_empty() {
@@ -127,7 +134,7 @@ impl RedisCluster {
                 .prefix("redis")
                 .tempdir()
                 .expect("failed to create tempdir");
-            let files = build_keys_and_certs_for_tls(&tempdir);
+            let files = build_keys_and_certs_for_tls_ext(&tempdir, certs_with_ip_alts);
             folders.push(tempdir);
             tls_paths = Some(files);
             is_tls = true;
@@ -265,6 +272,9 @@ impl RedisCluster {
                 cmd.arg(ca_crt);
                 cmd.arg("--tls");
             }
+        } else if !tls_insecure && tls_paths.is_some() {
+            let ca_crt = &tls_paths.as_ref().unwrap().ca_crt;
+            cmd.arg("--tls").arg("--cacert").arg(ca_crt);
         } else {
             cmd.arg("--tls").arg("--insecure");
         }
diff --git a/redis-test/src/utils.rs b/redis-test/src/utils.rs
index 16b5a40e5..24babee6a 100644
--- 
a/redis-test/src/utils.rs +++ b/redis-test/src/utils.rs @@ -13,6 +13,10 @@ pub struct TlsFilePaths { } pub fn build_keys_and_certs_for_tls(tempdir: &TempDir) -> TlsFilePaths { + build_keys_and_certs_for_tls_ext(tempdir, true) +} + +pub fn build_keys_and_certs_for_tls_ext(tempdir: &TempDir, with_ip_alts: bool) -> TlsFilePaths { // Based on shell script in redis's server tests // https://github.com/redis/redis/blob/8c291b97b95f2e011977b522acf77ead23e26f55/utils/gen-test-certs.sh let ca_crt = tempdir.path().join("ca.crt"); @@ -43,7 +47,7 @@ pub fn build_keys_and_certs_for_tls(tempdir: &TempDir) -> TlsFilePaths { make_key(&redis_key, 2048); // Build CA Cert - process::Command::new("openssl") + let status = process::Command::new("openssl") .arg("req") .arg("-x509") .arg("-new") @@ -63,16 +67,39 @@ pub fn build_keys_and_certs_for_tls(tempdir: &TempDir) -> TlsFilePaths { .expect("failed to spawn openssl") .wait() .expect("failed to create CA cert"); + assert!( + status.success(), + "`openssl req` failed to create CA cert: {status}" + ); // Build x509v3 extensions file - fs::write( - &ext_file, - b"keyUsage = digitalSignature, keyEncipherment\n\ - subjectAltName = @alt_names\n\ - [alt_names]\n\ - IP.1 = 127.0.0.1\n", - ) - .expect("failed to create x509v3 extensions file"); + let ext = if with_ip_alts { + "\ + keyUsage = digitalSignature, keyEncipherment\n\ + subjectAltName = @alt_names\n\ + [alt_names]\n\ + IP.1 = 127.0.0.1\n\ + " + } else { + "\ + [req]\n\ + distinguished_name = req_distinguished_name\n\ + x509_extensions = v3_req\n\ + prompt = no\n\ + \n\ + [req_distinguished_name]\n\ + CN = localhost.example.com\n\ + \n\ + [v3_req]\n\ + basicConstraints = CA:FALSE\n\ + keyUsage = nonRepudiation, digitalSignature, keyEncipherment\n\ + subjectAltName = @alt_names\n\ + \n\ + [alt_names]\n\ + DNS.1 = localhost.example.com\n\ + " + }; + fs::write(&ext_file, ext).expect("failed to create x509v3 extensions file"); // Read redis key let mut key_cmd = process::Command::new("openssl") @@ -89,7 +116,8 @@ pub fn build_keys_and_certs_for_tls(tempdir: &TempDir) -> TlsFilePaths { .expect("failed to spawn openssl"); // build redis cert - process::Command::new("openssl") + let mut command2 = process::Command::new("openssl"); + command2 .arg("x509") .arg("-req") .arg("-sha256") @@ -103,18 +131,28 @@ pub fn build_keys_and_certs_for_tls(tempdir: &TempDir) -> TlsFilePaths { .arg("-days") .arg("365") .arg("-extfile") - .arg(&ext_file) + .arg(&ext_file); + if !with_ip_alts { + command2.arg("-extensions").arg("v3_req"); + } + let status2 = command2 .arg("-out") .arg(&redis_crt) .stdin(key_cmd.stdout.take().expect("should have stdout")) - .stdout(process::Stdio::piped()) - .stderr(process::Stdio::piped()) .spawn() .expect("failed to spawn openssl") .wait() .expect("failed to create redis cert"); - key_cmd.wait().expect("failed to create redis key"); + let status = key_cmd.wait().expect("failed to create redis key"); + assert!( + status.success(), + "`openssl req` failed to create request for Redis cert: {status}" + ); + assert!( + status2.success(), + "`openssl x509` failed to create Redis cert: {status2}" + ); TlsFilePaths { redis_crt, diff --git a/redis/CHANGELOG.md b/redis/CHANGELOG.md index ec23d1b58..af2940b75 100644 --- a/redis/CHANGELOG.md +++ b/redis/CHANGELOG.md @@ -1,21 +1,93 @@ +### 0.30.0 (2025-04-22) + +#### Changes & Bug fixes + +* Add epoch for CacheManager ([#1583](https://github.com/redis-rs/redis-rs/pull/1583) by @altanozlu) +* **Breaking change** Add support for the Smol runtime. 
([#1606](https://github.com/redis-rs/redis-rs/pull/1606) by @nihohit) +* Add support for hash field expiration commands ([#1611](https://github.com/redis-rs/redis-rs/pull/1611) by @StefanPalashev) +* **Breaking change** Remove deprecated aio::Connection. ([#1613](https://github.com/redis-rs/redis-rs/pull/1613) by @nihohit) + +#### Documentation & CI improvements + +* Reduce number of flag-frenzy checks and format file. ([#1608](https://github.com/redis-rs/redis-rs/pull/1608) by @nihohit) +* Fix `make docs` ([#1607](https://github.com/redis-rs/redis-rs/pull/1607) by @somechris) +* Fail CI on warnings in docs. ([#1609](https://github.com/redis-rs/redis-rs/pull/1609) by @nihohit) + +### 0.29.5 (2025-04-06) + +#### Changes & Bug fixes + +* Fix build on Windows. ([#1601](https://github.com/redis-rs/redis-rs/pull/1601) by @nihohit) + +### 0.29.4 (2025-04-06) + +#### Changes & Bug fixes + +* Add async dns resolver for cluster config and fix doc ([#1595](https://github.com/redis-rs/redis-rs/pull/1595) by @wiserfz) +* Fix error kind declaration on non-unix machines. ([#1598](https://github.com/redis-rs/redis-rs/pull/1598) by @nihohit) + + +### 0.29.3 (2025-04-04) + +#### Changes & Bug fixes + +* re-export socket2. ([#1573](https://github.com/redis-rs/redis-rs/pull/1573) by @nihohit) +* Add commands to flush database(s) ([#1576](https://github.com/redis-rs/redis-rs/pull/1576) by @somechris) +* Fix valkey(s) url schemes not able to be converted to connection infos; Add valkey+unix url scheme ([#1574](https://github.com/redis-rs/redis-rs/pull/1574) by @MarkusTieger) +* A spec compliant version, with less changed code. ([#1572](https://github.com/redis-rs/redis-rs/pull/1572) by @65001) +* Support bb8 for cluster client ([#1577](https://github.com/redis-rs/redis-rs/pull/1577) by @wiserfz) +* Support custom DNS resolver for async client ([#1581](https://github.com/redis-rs/redis-rs/pull/1581) by @wiserfz) +* Update danger_accept_invalid_hostnames for rustls 0.23.24 ([#1592](https://github.com/redis-rs/redis-rs/pull/1592) by @jorendorff) + +### 0.29.2 (2025-03-21) + +#### Changes & Bug fixes + +* Add Valkey URL scheme ([#1558](https://github.com/redis-rs/redis-rs/pull/1558) by @displexic) +* Fix unreachable error when parsing a nested tuple. ([#1562](https://github.com/redis-rs/redis-rs/pull/1562) by @nihohit) +* Remove PFCOUNT and PFMERGE from the list of illegal cluster pipeline commands (#1565) ([#1566](https://github.com/redis-rs/redis-rs/pull/1566) by @stepanmracek) +* Remove EVALSHA from the list of illegal cluster pipeline commands ([#1568](https://github.com/redis-rs/redis-rs/pull/1568) by @stepanmracek) +* feat: Add bb8 support for async client ([#1564](https://github.com/redis-rs/redis-rs/pull/1564) by @Xuanwo) +* Add cache support to ConnectionManager. ([#1567](https://github.com/redis-rs/redis-rs/pull/1567) by @nihohit) +* perf: Run reconnection attempts concurrently ([#1557](https://github.com/redis-rs/redis-rs/pull/1557) by @Marwes) + +### 0.29.1 (2025-03-01) + +#### Changes & Bug fixes + +* Update rustls-native-certs. ([#1498](https://github.com/redis-rs/redis-rs/pull/1498)) +* Async cluster connection: Move response timeout out. ([#1532](https://github.com/redis-rs/redis-rs/pull/1532)) +* Expose `Pipeline.len()` function as public ([#1539](https://github.com/redis-rs/redis-rs/pull/1539) by @Harry-Lees) +* Implement `danger_accept_invalid_hostnames` option. ([#1529](https://github.com/redis-rs/redis-rs/pull/1529) by @jorendorff) +* Timeout on queuing requests. 
([#1552](https://github.com/redis-rs/redis-rs/pull/1552)) + +#### Documentation improvements + +* docs: Fix double quotes ([#1537](https://github.com/redis-rs/redis-rs/pull/1537) by @somechris) +* docs: added CLIENT SETINFO optional feature ([#1536](https://github.com/redis-rs/redis-rs/pull/1536) by @bourdeau) + +#### CI improvements + +* Run most cluster tests in secure mode when rustls is enabled ([#1534](https://github.com/redis-rs/redis-rs/pull/1534) by @jorendorff) + ### 0.29.0 (2025-02-16) #### Changes & Bug fixes -* Tweaks to rustls usage ([#1499] (https://github.com/redis-rs/redis-rs/pull/1499) by @djc) -* Add client side caching support for MultiplexedConnection ([#1296] (https://github.com/redis-rs/redis-rs/pull/1296) by @altanozlu) -* Add buffered write methods to RedisWrite ([#905] (https://github.com/redis-rs/redis-rs/pull/905) by @swwu) -* Include the reason for one connection failure if cluster connect fails ([#1497] (https://github.com/redis-rs/redis-rs/pull/1497) by @Marwes) -* Upgrade to rand 0.9 ([#1525] (https://github.com/redis-rs/redis-rs/pull/1525) by @gkorland) -* Allow configuring Sentinel with custom TLS certificates ([#1335] (https://github.com/redis-rs/redis-rs/pull/1335) by @ergonjomeier) -* Fix caching with the JSON module ([#1520] (https://github.com/redis-rs/redis-rs/pull/1520) by @kudlatyamroth) -* Allow users of async connections to set TCP settings. ([#1523] (https://github.com/redis-rs/redis-rs/pull/1523)) +* Tweaks to rustls usage ([#1499](https://github.com/redis-rs/redis-rs/pull/1499) by @djc) +* Add client side caching support for MultiplexedConnection ([#1296](https://github.com/redis-rs/redis-rs/pull/1296) by @altanozlu) +* Add buffered write methods to RedisWrite ([#905](https://github.com/redis-rs/redis-rs/pull/905) by @swwu) +* Include the reason for one connection failure if cluster connect fails ([#1497](https://github.com/redis-rs/redis-rs/pull/1497) by @Marwes) +* Upgrade to rand 0.9 ([#1525](https://github.com/redis-rs/redis-rs/pull/1525) by @gkorland) +* Allow configuring Sentinel with custom TLS certificates ([#1335](https://github.com/redis-rs/redis-rs/pull/1335) by @ergonjomeier) +* Fix caching with the JSON module ([#1520](https://github.com/redis-rs/redis-rs/pull/1520) by @kudlatyamroth) +* Allow users of async connections to set TCP settings. ([#1523](https://github.com/redis-rs/redis-rs/pull/1523)) #### Documentation improvements -* Clarify build instructions in README. ([#1515] (https://github.com/redis-rs/redis-rs/pull/1515)) -* Improve pubsub docs. ([#1519] (https://github.com/redis-rs/redis-rs/pull/1519)) -* Improve docs around TTLs/ expiry times. ([#1522] (https://github.com/redis-rs/redis-rs/pull/1522) by @clbarnes) +* Clarify build instructions in README. ([#1515](https://github.com/redis-rs/redis-rs/pull/1515)) +* Improve pubsub docs. ([#1519](https://github.com/redis-rs/redis-rs/pull/1519)) +* Improve docs around TTLs/ expiry times. ([#1522](https://github.com/redis-rs/redis-rs/pull/1522) by @clbarnes) ### 0.28.2 (2025-01-24) diff --git a/redis/Cargo.toml b/redis/Cargo.toml index cef4eb97d..a9e2a08bf 100644 --- a/redis/Cargo.toml +++ b/redis/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "redis" -version = "0.29.0" +version = "0.30.0" keywords = ["redis", "valkey", "cluster", "sentinel", "pubsub"] description = "Redis driver for Rust." 
homepage = "https://github.com/redis-rs/redis-rs" @@ -38,6 +38,7 @@ combine = { version = "4.6", default-features = false, features = ["std"] } # Only needed for AIO bytes = { version = "1", optional = true } +cfg-if = { version = "1", optional = true } futures-util = { version = "0.3.31", default-features = false, features = [ "std", "sink", @@ -55,11 +56,14 @@ socket2 = { version = "0.5", features = ["all"] } # Only needed for the connection manager arc-swap = { version = "1.7.1" } futures-channel = { version = "0.3.31", optional = true } -backon = { version = "1.3.0", optional = true, default-features = false } +backon = { version = "1.4.1", optional = true, default-features = false } # Only needed for the r2d2 feature r2d2 = { version = "0.8.10", optional = true } +# Only needed for the bb8 feature +bb8 = { version = "0.9.0", optional = true } + # Only needed for cluster crc16 = { version = "0.4", optional = true } rand = { version = "0.9", optional = true } @@ -70,6 +74,11 @@ futures-sink = { version = "0.3.31", optional = true } # Only needed for async_std support async-std = { version = "1.13.0", optional = true } +#only needed for smol support +smol = { version = "2", optional = true } +async-io = { version = "2", optional = true } +smol-timeout = { version = "0.6", optional = true } + # Only needed for native tls native-tls = { version = "0.2", optional = true } tokio-native-tls = { version = "0.3", optional = true } @@ -78,13 +87,13 @@ async-native-tls = { version = "0.5", optional = true } # Only needed for rustls rustls = { version = "0.23", optional = true, default-features = false } webpki-roots = { version = "0.26", optional = true } -rustls-native-certs = { version = "0.7", optional = true } +rustls-native-certs = { version = "0.8", optional = true } tokio-rustls = { version = "0.26", optional = true, default-features = false } futures-rustls = { version = "0.26", optional = true, default-features = false } # Only needed for RedisJSON Support -serde = { version = "1.0.217", optional = true } -serde_json = { version = "1.0.138", optional = true } +serde = { version = "1.0.218", optional = true } +serde_json = { version = "1.0.139", optional = true } # Only needed for bignum Support rust_decimal = { version = "1.36.0", optional = true } @@ -97,7 +106,7 @@ ahash = { version = "0.8.11", optional = true } log = { version = "0.4", optional = true } # Optional uuid support -uuid = { version = "1.12.1", optional = true } +uuid = { version = "1.15.1", optional = true } # Optional hashbrown support hashbrown = { version = "0.15", optional = true } @@ -120,13 +129,13 @@ tls-rustls = [ ] tls-rustls-insecure = ["tls-rustls"] tls-rustls-webpki-roots = ["tls-rustls", "dep:webpki-roots"] -async-std-comp = ["aio", "dep:async-std"] -async-std-native-tls-comp = [ - "async-std-comp", +smol-comp = ["aio", "dep:smol", "dep:smol-timeout", "dep:async-io"] +smol-native-tls-comp = [ + "smol-comp", "dep:async-native-tls", "tls-native-tls", ] -async-std-rustls-comp = ["async-std-comp", "dep:futures-rustls", "tls-rustls"] +smol-rustls-comp = ["smol-comp", "dep:futures-rustls", "tls-rustls"] tokio-comp = ["aio", "tokio/net"] tokio-native-tls-comp = ["tokio-comp", "tls-native-tls", "dep:tokio-native-tls"] tokio-rustls-comp = ["tokio-comp", "tls-rustls", "dep:tokio-rustls"] @@ -139,12 +148,21 @@ tcp_nodelay = [] num-bigint = [] disable-client-setinfo = [] cache-aio = ["aio", "dep:lru"] +r2d2 = ["dep:r2d2"] +bb8 = ["dep:bb8"] # Deprecated features tls = ["tls-native-tls"] # use "tls-native-tls" instead 
async-std-tls-comp = [
    "async-std-native-tls-comp",
] # use "async-std-native-tls-comp" instead
+async-std-comp = ["aio", "dep:async-std"]
+async-std-native-tls-comp = [
+    "async-std-comp",
+    "dep:async-native-tls",
+    "tls-native-tls",
+]
+async-std-rustls-comp = ["async-std-comp", "dep:futures-rustls", "tls-rustls"]
 
 # Instead of specifying "aio", use either "tokio-comp" or "async-std-comp".
 aio = [
     "bytes",
@@ -155,6 +173,7 @@ aio = [
     "dep:tokio-util",
     "tokio-util/codec",
     "combine/tokio",
+    "dep:cfg-if",
 ]
 
 [dev-dependencies]
@@ -172,11 +191,12 @@ tokio = { version = "1", features = [
     "test-util",
     "time",
 ] }
-tempfile = "=3.16.0"
+tempfile = "=3.19.1"
 once_cell = "1"
 anyhow = "1"
 redis-test = { path = "../redis-test" }
 rstest = "0.24"
+rand = "0.9"
 
 [[test]]
 name = "test_async"
diff --git a/redis/examples/async-connection-loss.rs b/redis/examples/async-connection-loss.rs
index 557fa4de6..dd221f834 100644
--- a/redis/examples/async-connection-loss.rs
+++ b/redis/examples/async-connection-loss.rs
@@ -16,22 +16,10 @@ use redis::RedisResult;
 use tokio::time::interval;
 
 enum Mode {
-    Deprecated,
     Default,
     Reconnect,
 }
 
-async fn run_single<C: ConnectionLike>(mut con: C) -> RedisResult<()> {
-    let mut interval = interval(Duration::from_millis(100));
-    loop {
-        interval.tick().await;
-        println!();
-        println!("> PING");
-        let result: RedisResult<String> = redis::cmd("PING").query_async(&mut con).await;
-        println!("< {result:?}");
-    }
-}
-
 async fn run_multi<C: ConnectionLike>(mut con: C) -> RedisResult<()> {
     let mut interval = interval(Duration::from_millis(100));
     loop {
@@ -67,10 +55,6 @@ async fn main() -> RedisResult<()> {
             println!("Using reconnect manager mode\n");
             Mode::Reconnect
         }
-        Some("deprecated") => {
-            println!("Using deprecated connection mode\n");
-            Mode::Deprecated
-        }
         Some(_) | None => {
             println!("Usage: reconnect-manager (default|multiplexed|reconnect)");
             process::exit(1);
@@ -81,8 +65,6 @@ async fn main() -> RedisResult<()> {
     match mode {
         Mode::Default => run_multi(client.get_multiplexed_async_connection().await?).await?,
         Mode::Reconnect => run_multi(client.get_connection_manager().await?).await?,
-        #[allow(deprecated)]
-        Mode::Deprecated => run_single(client.get_async_connection().await?).await?,
     };
     Ok(())
 }
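[Editor's note, not part of the diff: a minimal sketch of the reconnecting flavor that the example keeps, using the same `get_connection_manager` call as `Mode::Reconnect` above. It assumes a local server and the `connection-manager` plus `tokio-comp` features; stopping and restarting the server mid-run is what the example is designed to demonstrate.]

```rust
use redis::RedisResult;

#[tokio::main]
async fn main() -> RedisResult<()> {
    let client = redis::Client::open("redis://127.0.0.1/")?;
    // The connection manager clones cheaply and retries/reconnects
    // internally instead of surfacing a dead connection.
    let mut con = client.get_connection_manager().await?;
    let pong: String = redis::cmd("PING").query_async(&mut con).await?;
    println!("{pong}");
    Ok(())
}
```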
diff --git a/redis/src/aio/async_std.rs b/redis/src/aio/async_std.rs
index d9d2872c6..6e61496d8 100644
--- a/redis/src/aio/async_std.rs
+++ b/redis/src/aio/async_std.rs
@@ -1,6 +1,6 @@
 #[cfg(unix)]
 use std::path::Path;
-#[cfg(feature = "tls-rustls")]
+#[cfg(feature = "async-std-rustls-comp")]
 use std::sync::Arc;
 use std::{
     future::Future,
@@ -13,12 +13,15 @@ use std::{
 use crate::aio::{AsyncStream, RedisRuntime};
 use crate::types::RedisResult;
 
-#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
+#[cfg(all(
+    feature = "async-std-native-tls-comp",
+    not(feature = "async-std-rustls-comp")
+))]
 use async_native_tls::{TlsConnector, TlsStream};
 
-#[cfg(feature = "tls-rustls")]
+#[cfg(feature = "async-std-rustls-comp")]
 use crate::connection::create_rustls_config;
-#[cfg(feature = "tls-rustls")]
+#[cfg(feature = "async-std-rustls-comp")]
 use futures_rustls::{client::TlsStream, TlsConnector};
 
 use super::TaskHandle;
@@ -39,10 +42,11 @@ async fn connect_tcp(
     Ok(std_socket.into())
 }
 
-#[cfg(feature = "tls-rustls")]
-use crate::tls::TlsConnParams;
-
-#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
+#[cfg(any(
+    feature = "async-std-rustls-comp",
+    feature = "async-std-native-tls-comp"
+))]
 use crate::connection::TlsConnParams;
 
 pin_project_lite::pin_project! {
@@ -192,12 +196,15 @@ impl RedisRuntime for AsyncStd {
             .map(|con| Self::Tcp(AsyncStdWrapped::new(con)))?)
     }
 
-    #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
+    #[cfg(all(
+        feature = "async-std-native-tls-comp",
+        not(feature = "async-std-rustls-comp")
+    ))]
     async fn connect_tcp_tls(
         hostname: &str,
         socket_addr: SocketAddr,
         insecure: bool,
-        _tls_params: &Option<TlsConnParams>,
+        tls_params: &Option<TlsConnParams>,
         tcp_settings: &crate::io::tcp::TcpSettings,
     ) -> RedisResult<Self> {
         let tcp_stream = connect_tcp(&socket_addr, tcp_settings).await?;
@@ -206,6 +213,9 @@ impl RedisRuntime for AsyncStd {
             .danger_accept_invalid_certs(true)
             .danger_accept_invalid_hostnames(true)
             .use_sni(false)
+        } else if let Some(params) = tls_params {
+            TlsConnector::new()
+                .danger_accept_invalid_hostnames(params.danger_accept_invalid_hostnames)
         } else {
             TlsConnector::new()
         };
@@ -215,7 +225,7 @@ impl RedisRuntime for AsyncStd {
             .map(|con| Self::TcpTls(AsyncStdWrapped::new(Box::new(con))))?)
     }
 
-    #[cfg(feature = "tls-rustls")]
+    #[cfg(feature = "async-std-rustls-comp")]
     async fn connect_tcp_tls(
         hostname: &str,
         socket_addr: SocketAddr,
diff --git a/redis/src/aio/connection.rs b/redis/src/aio/connection.rs
index 360796f05..d9673f1a9 100644
--- a/redis/src/aio/connection.rs
+++ b/redis/src/aio/connection.rs
@@ -1,469 +1,30 @@
-#![allow(deprecated)]
+use super::AsyncDNSResolver;
+use super::RedisRuntime;
 
-#[cfg(feature = "async-std-comp")]
-use super::async_std;
-use super::ConnectionLike;
-use super::{setup_connection, AsyncStream, RedisRuntime};
-use crate::cmd::{cmd, Cmd};
-use crate::connection::{
-    resp2_is_pub_sub_state_cleared, resp3_is_pub_sub_state_cleared, ConnectionAddr, ConnectionInfo,
-    Msg, RedisConnectionInfo,
-};
+use crate::connection::{ConnectionAddr, ConnectionInfo};
 use crate::io::tcp::TcpSettings;
-#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
-use crate::parser::ValueCodec;
-use crate::types::{ErrorKind, FromRedisValue, RedisError, RedisFuture, RedisResult, Value};
-use crate::{from_owned_redis_value, ProtocolVersion, ToRedisArgs};
-#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
-use ::async_std::net::ToSocketAddrs;
-use ::tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
-#[cfg(feature = "tokio-comp")]
-use ::tokio::net::lookup_host;
-use combine::{parser::combinator::AnySendSyncPartialState, stream::PointerOffset};
-use futures_util::future::select_ok;
-use futures_util::{
-    future::FutureExt,
-    stream::{Stream, StreamExt},
-};
-use std::net::SocketAddr;
-use std::pin::Pin;
-#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
-use tokio_util::codec::Decoder;
-
-/// Represents a stateful redis TCP connection.
-#[deprecated(note = "aio::Connection is deprecated. Use aio::MultiplexedConnection instead.")]
-pub struct Connection<C = Pin<Box<dyn AsyncStream + Send + Sync>>> {
-    con: C,
-    buf: Vec<u8>,
-    decoder: combine::stream::Decoder<AnySendSyncPartialState, PointerOffset<[u8]>>,
-    db: i64,
+#[cfg(feature = "aio")]
+use crate::types::RedisResult;
 
-    // Flag indicating whether the connection was left in the PubSub state after dropping `PubSub`.
-    //
-    // This flag is checked when attempting to send a command, and if it's raised, we attempt to
-    // exit the pubsub state before executing the new request.
-    pubsub: bool,
-
-    // Field indicating which protocol to use for server communications.
-    protocol: ProtocolVersion,
-}
 
+use futures_util::future::select_ok;
 
-fn assert_sync<T: Sync>() {}
+const fn assert_sync<T: Sync>() {}
 
 #[allow(unused)]
-fn test() {
-    assert_sync::<Connection>();
-}
-
-impl<C> Connection<C> {
-    pub(crate) fn map<D>(self, f: impl FnOnce(C) -> D) -> Connection<D> {
-        let Self {
-            con,
-            buf,
-            decoder,
-            db,
-            pubsub,
-            protocol,
-        } = self;
-        Connection {
-            con: f(con),
-            buf,
-            decoder,
-            db,
-            pubsub,
-            protocol,
-        }
-    }
-}
-
-impl<C> Connection<C>
-where
-    C: Unpin + AsyncRead + AsyncWrite + Send,
-{
-    /// Constructs a new `Connection` out of a `AsyncRead + AsyncWrite` object
-    /// and a `RedisConnectionInfo`
-    pub async fn new(connection_info: &RedisConnectionInfo, con: C) -> RedisResult<Self> {
-        let mut rv = Connection {
-            con,
-            buf: Vec::new(),
-            decoder: combine::stream::Decoder::new(),
-            db: connection_info.db,
-            pubsub: false,
-            protocol: connection_info.protocol,
-        };
-        setup_connection(
-            connection_info,
-            &mut rv,
-            #[cfg(feature = "cache-aio")]
-            None,
-        )
-        .await?;
-        Ok(rv)
-    }
-
-    /// Converts this [`Connection`] into [`PubSub`].
-    #[deprecated(note = "aio::Connection is deprecated. Use [Client::get_async_pubsub] instead")]
-    pub fn into_pubsub(self) -> PubSub<C> {
-        PubSub::new(self)
-    }
-
-    /// Converts this [`Connection`] into [`Monitor`]
-    #[deprecated(note = "aio::Connection is deprecated. Use [Client::get_async_pubsub] instead")]
-    pub fn into_monitor(self) -> Monitor<C> {
-        Monitor::new(self)
-    }
-
-    /// Fetches a single response from the connection.
-    async fn read_response(&mut self) -> RedisResult<Value> {
-        crate::parser::parse_redis_value_async(&mut self.decoder, &mut self.con).await
-    }
-
-    /// Brings [`Connection`] out of `PubSub` mode.
-    ///
-    /// This will unsubscribe this [`Connection`] from all subscriptions.
-    ///
-    /// If this function returns error then on all command send tries will be performed attempt
-    /// to exit from `PubSub` mode until it will be successful.
-    async fn exit_pubsub(&mut self) -> RedisResult<()> {
-        let res = self.clear_active_subscriptions().await;
-        if res.is_ok() {
-            self.pubsub = false;
-        } else {
-            // Raise the pubsub flag to indicate the connection is "stuck" in that state.
-            self.pubsub = true;
-        }
-
-        res
-    }
-
-    /// Get the inner connection out of a PubSub
-    ///
-    /// Any active subscriptions are unsubscribed. In the event of an error, the connection is
-    /// dropped.
-    async fn clear_active_subscriptions(&mut self) -> RedisResult<()> {
-        // Responses to unsubscribe commands return in a 3-tuple with values
-        // ("unsubscribe" or "punsubscribe", name of subscription removed, count of remaining subs).
-        // The "count of remaining subs" includes both pattern subscriptions and non pattern
-        // subscriptions. Thus, to accurately drain all unsubscribe messages received from the
-        // server, both commands need to be executed at once.
-        {
-            // Prepare both unsubscribe commands
-            let unsubscribe = crate::Pipeline::new()
-                .add_command(cmd("UNSUBSCRIBE"))
-                .add_command(cmd("PUNSUBSCRIBE"))
-                .get_packed_pipeline();
-
-            // Execute commands
-            self.con.write_all(&unsubscribe).await?;
-        }
-
-        // Receive responses
-        //
-        // There will be at minimum two responses - 1 for each of punsubscribe and unsubscribe
-        // commands. There may be more responses if there are active subscriptions. In this case,
-        // messages are received until the _subscription count_ in the responses reach zero.
-        let mut received_unsub = false;
-        let mut received_punsub = false;
-        if self.protocol != ProtocolVersion::RESP2 {
-            while let Value::Push { kind, data } =
-                from_owned_redis_value(self.read_response().await?)?
-            {
-                if data.len() >= 2 {
-                    if let Value::Int(num) = data[1] {
-                        if resp3_is_pub_sub_state_cleared(
-                            &mut received_unsub,
-                            &mut received_punsub,
-                            &kind,
-                            num as isize,
-                        ) {
-                            break;
-                        }
-                    }
-                }
-            }
-        } else {
-            loop {
-                let res: (Vec<u8>, (), isize) =
-                    from_owned_redis_value(self.read_response().await?)?;
-                if resp2_is_pub_sub_state_cleared(
-                    &mut received_unsub,
-                    &mut received_punsub,
-                    &res.0,
-                    res.2,
-                ) {
-                    break;
-                }
-            }
-        }
-
-        // Finally, the connection is back in its normal state since all subscriptions were
-        // cancelled *and* all unsubscribe messages were received.
-        Ok(())
-    }
-}
-
-#[cfg(feature = "async-std-comp")]
-#[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))]
-impl<C> Connection<async_std::AsyncStdWrapped<C>>
-where
-    C: Unpin + ::async_std::io::Read + ::async_std::io::Write + Send,
-{
-    /// Constructs a new `Connection` out of a `async_std::io::AsyncRead + async_std::io::AsyncWrite` object
-    /// and a `RedisConnectionInfo`
-    pub async fn new_async_std(connection_info: &RedisConnectionInfo, con: C) -> RedisResult<Self> {
-        Connection::new(connection_info, async_std::AsyncStdWrapped::new(con)).await
-    }
-}
-
-pub(crate) async fn connect<C>(connection_info: &ConnectionInfo) -> RedisResult<Connection<C>>
-where
-    C: Unpin + RedisRuntime + AsyncRead + AsyncWrite + Send,
-{
-    let con = connect_simple::<C>(connection_info, &TcpSettings::default()).await?;
-    Connection::new(&connection_info.redis, con).await
-}
-
-impl<C> ConnectionLike for Connection<C>
-where
-    C: Unpin + AsyncRead + AsyncWrite + Send,
-{
-    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
-        (async move {
-            if self.pubsub {
-                self.exit_pubsub().await?;
-            }
-            self.buf.clear();
-            cmd.write_packed_command(&mut self.buf);
-            self.con.write_all(&self.buf).await?;
-            if cmd.is_no_response() {
-                return Ok(Value::Nil);
-            }
-            loop {
-                match self.read_response().await? {
-                    Value::Push { .. } => continue,
-                    val => return Ok(val),
-                }
-            }
-        })
-        .boxed()
-    }
-
-    fn req_packed_commands<'a>(
-        &'a mut self,
-        cmd: &'a crate::Pipeline,
-        offset: usize,
-        count: usize,
-    ) -> RedisFuture<'a, Vec<Value>> {
-        (async move {
-            if self.pubsub {
-                self.exit_pubsub().await?;
-            }
-
-            self.buf.clear();
-            cmd.write_packed_pipeline(&mut self.buf);
-            self.con.write_all(&self.buf).await?;
-
-            let mut first_err = None;
-
-            for _ in 0..offset {
-                let response = self.read_response().await;
-                match response {
-                    Ok(Value::ServerError(err)) => {
-                        if first_err.is_none() {
-                            first_err = Some(err.into());
-                        }
-                    }
-                    Err(err) => {
-                        if first_err.is_none() {
-                            first_err = Some(err);
-                        }
-                    }
-                    _ => {}
-                }
-            }
-
-            let mut rv = Vec::with_capacity(count);
-            let mut count = count;
-            let mut idx = 0;
-            while idx < count {
-                let response = self.read_response().await;
-                match response {
-                    Ok(item) => {
-                        // RESP3 can insert push data between command replies
-                        if let Value::Push { .. } = item {
-                            // if that is the case we have to extend the loop and handle push data
-                            count += 1;
-                        } else {
-                            rv.push(item);
-                        }
-                    }
-                    Err(err) => {
-                        if first_err.is_none() {
-                            first_err = Some(err);
-                        }
-                    }
-                }
-                idx += 1;
-            }
-
-            if let Some(err) = first_err {
-                Err(err)
-            } else {
-                Ok(rv)
-            }
-        })
-        .boxed()
-    }
-
-    fn get_db(&self) -> i64 {
-        self.db
-    }
-}
-
-/// Represents a `PubSub` connection.
-pub struct PubSub<C = Pin<Box<dyn AsyncStream + Send + Sync>>>(Connection<C>);
-
-/// Represents a `Monitor` connection.
-pub struct Monitor>>(Connection); - -impl PubSub -where - C: Unpin + AsyncRead + AsyncWrite + Send, -{ - fn new(con: Connection) -> Self { - Self(con) - } - - /// Subscribes to a new channel(s). - pub async fn subscribe(&mut self, channel: T) -> RedisResult<()> { - let mut cmd = cmd("SUBSCRIBE"); - cmd.arg(channel); - if self.0.protocol != ProtocolVersion::RESP2 { - cmd.set_no_response(true); - } - cmd.query_async(&mut self.0).await - } - - /// Subscribes to new channel(s) with pattern(s). - pub async fn psubscribe(&mut self, pchannel: T) -> RedisResult<()> { - let mut cmd = cmd("PSUBSCRIBE"); - cmd.arg(pchannel); - if self.0.protocol != ProtocolVersion::RESP2 { - cmd.set_no_response(true); - } - cmd.query_async(&mut self.0).await - } - - /// Unsubscribes from a channel. - pub async fn unsubscribe(&mut self, channel: T) -> RedisResult<()> { - let mut cmd = cmd("UNSUBSCRIBE"); - cmd.arg(channel); - if self.0.protocol != ProtocolVersion::RESP2 { - cmd.set_no_response(true); - } - cmd.query_async(&mut self.0).await - } - - /// Unsubscribes from channel pattern(s). - pub async fn punsubscribe(&mut self, pchannel: T) -> RedisResult<()> { - let mut cmd = cmd("PUNSUBSCRIBE"); - cmd.arg(pchannel); - if self.0.protocol != ProtocolVersion::RESP2 { - cmd.set_no_response(true); - } - cmd.query_async(&mut self.0).await - } - - /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]s subscriptions. - /// - /// The message itself is still generic and can be converted into an appropriate type through - /// the helper methods on it. - pub fn on_message(&mut self) -> impl Stream + '_ { - ValueCodec::default() - .framed(&mut self.0.con) - .filter_map(|msg| Box::pin(async move { Msg::from_owned_value(msg.ok()?) })) - } - - /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]s subscriptions consuming it. - /// - /// The message itself is still generic and can be converted into an appropriate type through - /// the helper methods on it. - /// This can be useful in cases where the stream needs to be returned or held by something other - /// than the [`PubSub`]. - pub fn into_on_message(self) -> impl Stream { - ValueCodec::default() - .framed(self.0.con) - .filter_map(|msg| Box::pin(async move { Msg::from_owned_value(msg.ok()?) })) - } - - /// Exits from `PubSub` mode and converts [`PubSub`] into [`Connection`]. - #[deprecated(note = "aio::Connection is deprecated")] - pub async fn into_connection(mut self) -> Connection { - self.0.exit_pubsub().await.ok(); - - self.0 - } -} - -impl Monitor -where - C: Unpin + AsyncRead + AsyncWrite + Send, -{ - /// Create a [`Monitor`] from a [`Connection`] - pub fn new(con: Connection) -> Self { - Self(con) - } - - /// Deliver the MONITOR command to this [`Monitor`]ing wrapper. 
- pub async fn monitor(&mut self) -> RedisResult<()> { - cmd("MONITOR").query_async(&mut self.0).await - } - - /// Returns [`Stream`] of [`FromRedisValue`] values from this [`Monitor`]ing connection - pub fn on_message(&mut self) -> impl Stream + '_ { - ValueCodec::default() - .framed(&mut self.0.con) - .filter_map(|value| { - Box::pin(async move { T::from_owned_redis_value(value.ok()?).ok() }) - }) - } - - /// Returns [`Stream`] of [`FromRedisValue`] values from this [`Monitor`]ing connection - pub fn into_on_message(self) -> impl Stream { - ValueCodec::default() - .framed(self.0.con) - .filter_map(|value| { - Box::pin(async move { T::from_owned_redis_value(value.ok()?).ok() }) - }) - } -} - -async fn get_socket_addrs( - host: &str, - port: u16, -) -> RedisResult + Send + '_> { - #[cfg(feature = "tokio-comp")] - let socket_addrs = lookup_host((host, port)).await?; - #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] - let socket_addrs = (host, port).to_socket_addrs().await?; - - let mut socket_addrs = socket_addrs.peekable(); - match socket_addrs.peek() { - Some(_) => Ok(socket_addrs), - None => Err(RedisError::from(( - ErrorKind::InvalidClientConfig, - "No address found for host", - ))), - } +fn test_is_sync() { + assert_sync::(); + assert_sync::(); + assert_sync::(); } pub(crate) async fn connect_simple( connection_info: &ConnectionInfo, + dns_resolver: &dyn AsyncDNSResolver, tcp_settings: &TcpSettings, ) -> RedisResult { Ok(match connection_info.addr { ConnectionAddr::Tcp(ref host, port) => { - let socket_addrs = get_socket_addrs(host, port).await?; + let socket_addrs = dns_resolver.resolve(host, port).await?; select_ok(socket_addrs.map(|addr| Box::pin(::connect_tcp(addr, tcp_settings)))) .await? .0 @@ -476,7 +37,7 @@ pub(crate) async fn connect_simple( insecure, ref tls_params, } => { - let socket_addrs = get_socket_addrs(host, port).await?; + let socket_addrs = dns_resolver.resolve(host, port).await?; select_ok(socket_addrs.map(|socket_addr| { Box::pin(::connect_tcp_tls( host, @@ -493,7 +54,7 @@ pub(crate) async fn connect_simple( #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))] ConnectionAddr::TcpTls { .. } => { fail!(( - ErrorKind::InvalidClientConfig, + crate::types::ErrorKind::InvalidClientConfig, "Cannot connect to TCP with TLS without the tls feature" )); } @@ -503,11 +64,11 @@ pub(crate) async fn connect_simple( #[cfg(not(unix))] ConnectionAddr::Unix(_) => { - return Err(RedisError::from(( - ErrorKind::InvalidClientConfig, + fail!(( + crate::types::ErrorKind::InvalidClientConfig, "Cannot connect to unix sockets \ on this platform", - ))) + )) } }) } diff --git a/redis/src/aio/connection_manager.rs b/redis/src/aio/connection_manager.rs index d9ed21173..f3807c8ff 100644 --- a/redis/src/aio/connection_manager.rs +++ b/redis/src/aio/connection_manager.rs @@ -1,4 +1,6 @@ use super::{AsyncPushSender, HandleContainer, RedisFuture}; +#[cfg(feature = "cache-aio")] +use crate::caching::CacheManager; use crate::{ aio::{check_resp3, ConnectionLike, MultiplexedConnection, Runtime}, cmd, @@ -37,6 +39,8 @@ pub struct ConnectionManagerConfig { /// if true, the manager should resubscribe automatically to all pubsub channels after reconnect. 
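// ---- editor's sketch -----------------------------------------------------
// How the `connect_simple` hunk above races one connect future per resolved
// address: `futures_util::future::select_ok` settles on the first success and
// only fails once every candidate has failed. `dial` is a hypothetical
// stand-in for `<T as RedisRuntime>::connect_tcp`.
use futures_util::future::select_ok;
use std::net::SocketAddr;

async fn dial(addr: SocketAddr) -> std::io::Result<SocketAddr> {
    Ok(addr) // a real implementation would return the connected stream
}

async fn connect_first(addrs: Vec<SocketAddr>) -> std::io::Result<SocketAddr> {
    // Note: `select_ok` panics on an empty iterator, which is why the
    // resolver guarantees at least one address before this point.
    let (connected, _rest) = select_ok(addrs.into_iter().map(|a| Box::pin(dial(a)))).await?;
    Ok(connected)
}
// ---------------------------------------------------------------------------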
resubscribe_automatically: bool, tcp_settings: crate::io::tcp::TcpSettings, + #[cfg(feature = "cache-aio")] + pub(crate) cache_config: Option, } impl std::fmt::Debug for ConnectionManagerConfig { @@ -51,9 +55,11 @@ impl std::fmt::Debug for ConnectionManagerConfig { push_sender, resubscribe_automatically, tcp_settings, + #[cfg(feature = "cache-aio")] + cache_config, } = &self; - f.debug_struct("ConnectionManagerConfig") - .field("exponent_base", &exponent_base) + let mut str = f.debug_struct("ConnectionManagerConfig"); + str.field("exponent_base", &exponent_base) .field("factor", &factor) .field("number_of_retries", &number_of_retries) .field("max_delay", &max_delay) @@ -68,8 +74,12 @@ impl std::fmt::Debug for ConnectionManagerConfig { &"not set" }, ) - .field("tcp_settings", &tcp_settings) - .finish() + .field("tcp_settings", &tcp_settings); + + #[cfg(feature = "cache-aio")] + str.field("cache_config", &cache_config); + + str.finish() } } @@ -176,6 +186,15 @@ impl ConnectionManagerConfig { ..self } } + + /// Set the cache behavior. + #[cfg(feature = "cache-aio")] + pub fn set_cache_config(self, cache_config: crate::caching::CacheConfig) -> Self { + Self { + cache_config: Some(cache_config), + ..self + } + } } impl Default for ConnectionManagerConfig { @@ -190,6 +209,8 @@ impl Default for ConnectionManagerConfig { push_sender: None, resubscribe_automatically: false, tcp_settings: Default::default(), + #[cfg(feature = "cache-aio")] + cache_config: None, } } } @@ -207,6 +228,8 @@ struct Internals { retry_strategy: ExponentialBuilder, connection_config: AsyncConnectionConfig, subscription_tracker: Option>, + #[cfg(feature = "cache-aio")] + cache_manager: Option, _task_handle: HandleContainer, } @@ -372,6 +395,15 @@ impl ConnectionManager { connection_config = connection_config.set_response_timeout(response_timeout); } connection_config = connection_config.set_tcp_settings(config.tcp_settings); + #[cfg(feature = "cache-aio")] + let cache_manager = config + .cache_config + .as_ref() + .map(|cache_config| CacheManager::new(*cache_config)); + #[cfg(feature = "cache-aio")] + if let Some(cache_manager) = cache_manager.as_ref() { + connection_config = connection_config.set_cache_manager(cache_manager.clone()); + } let (oneshot_sender, oneshot_receiver) = oneshot::channel(); let _task_handle = HandleContainer::new( @@ -407,6 +439,8 @@ impl ConnectionManager { retry_strategy, connection_config, subscription_tracker, + #[cfg(feature = "cache-aio")] + cache_manager, _task_handle, })); @@ -453,6 +487,15 @@ impl ConnectionManager { /// when the connection loss was detected. 
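// ---- editor's sketch -----------------------------------------------------
// Usage sketch for the cache plumbing added to `ConnectionManagerConfig`,
// assuming the `connection-manager` and `cache-aio` features. The builder
// methods come from the surrounding hunks; `get_connection_manager_with_config`
// is assumed from the crate's public client API.
async fn manager_with_cache(client: redis::Client) -> redis::RedisResult<()> {
    use redis::{aio::ConnectionManagerConfig, caching::CacheConfig};

    let config = ConnectionManagerConfig::default().set_cache_config(CacheConfig::new());
    let manager = client.get_connection_manager_with_config(config).await?;

    // `get_cache_statistics` (added below) returns `None` when the manager
    // was built without a cache config.
    if let Some(stats) = manager.get_cache_statistics() {
        let _ = stats; // inspect hit/miss counters here
    }
    Ok(())
}
// ---------------------------------------------------------------------------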
fn reconnect(&self, current: arc_swap::Guard>>) { let self_clone = self.clone(); + #[cfg(not(feature = "cache-aio"))] + let connection_config = self_clone.0.connection_config.clone(); + #[cfg(feature = "cache-aio")] + let mut connection_config = self_clone.0.connection_config.clone(); + #[cfg(feature = "cache-aio")] + if let Some(manager) = self.0.cache_manager.as_ref() { + let new_cache_manager = manager.clone_and_increase_epoch(); + connection_config = connection_config.set_cache_manager(new_cache_manager); + } let new_connection: SharedRedisFuture = async move { let additional_commands = match &self_clone.0.subscription_tracker { Some(subscription_tracker) => Some( @@ -463,10 +506,11 @@ impl ConnectionManager { ), None => None, }; + let con = Self::new_connection( &self_clone.0.client, self_clone.0.retry_strategy, - &self_clone.0.connection_config, + &connection_config, additional_commands, ) .await?; @@ -621,6 +665,13 @@ impl ConnectionManager { .await; Ok(()) } + + /// Gets [`crate::caching::CacheStatistics`] for current connection if caching is enabled. + #[cfg(feature = "cache-aio")] + #[cfg_attr(docsrs, doc(cfg(feature = "cache-aio")))] + pub fn get_cache_statistics(&self) -> Option { + self.0.cache_manager.as_ref().map(|cm| cm.statistics()) + } } impl ConnectionLike for ConnectionManager { diff --git a/redis/src/aio/mod.rs b/redis/src/aio/mod.rs index 563449405..ba507d786 100644 --- a/redis/src/aio/mod.rs +++ b/redis/src/aio/mod.rs @@ -4,26 +4,35 @@ use crate::connection::{ check_connection_setup, connection_setup_pipeline, AuthResult, ConnectionSetupComponents, RedisConnectionInfo, }; -use crate::types::{RedisFuture, RedisResult, Value}; -use crate::PushInfo; +use crate::io::AsyncDNSResolver; +use crate::types::{closed_connection_error, RedisFuture, RedisResult, Value}; +use crate::{ErrorKind, PushInfo, RedisError}; use ::tokio::io::{AsyncRead, AsyncWrite}; -use futures_util::Future; +use futures_util::{ + future::{Future, FutureExt}, + sink::{Sink, SinkExt}, + stream::{Stream, StreamExt}, +}; +pub use monitor::Monitor; use std::net::SocketAddr; #[cfg(unix)] use std::path::Path; use std::pin::Pin; +mod monitor; + /// Enables the async_std compatibility #[cfg(feature = "async-std-comp")] #[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] pub mod async_std; -#[cfg(feature = "tls-rustls")] -use crate::tls::TlsConnParams; - -#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] +#[cfg(any(feature = "tls-rustls", feature = "tls-native-tls"))] use crate::connection::TlsConnParams; +/// Enables the smol compatibility +#[cfg(feature = "smol-comp")] +#[cfg_attr(docsrs, doc(cfg(feature = "smol-comp")))] +pub mod smol; /// Enables the tokio compatibility #[cfg(feature = "tokio-comp")] #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] @@ -93,27 +102,42 @@ pub trait ConnectionLike { fn get_db(&self) -> i64; } -async fn execute_connection_pipeline( - rv: &mut impl ConnectionLike, +async fn execute_connection_pipeline( + codec: &mut T, (pipeline, instructions): (crate::Pipeline, ConnectionSetupComponents), -) -> RedisResult { - if pipeline.len() == 0 { +) -> RedisResult +where + T: Sink, Error = RedisError>, + T: Stream>, + T: Unpin + Send + 'static, +{ + let count = pipeline.len(); + if count == 0 { return Ok(AuthResult::Succeeded); } + codec.send(pipeline.get_packed_pipeline()).await?; - let results = rv.req_packed_commands(&pipeline, 0, pipeline.len()).await?; + let mut results = Vec::with_capacity(count); + for _ in 0..count { + let value = 
codec.next().await.ok_or_else(closed_connection_error)??; + results.push(value); + } check_connection_setup(results, instructions) } -// Initial setup for every connection. -async fn setup_connection( +pub(super) async fn setup_connection( + codec: &mut T, connection_info: &RedisConnectionInfo, - con: &mut impl ConnectionLike, #[cfg(feature = "cache-aio")] cache_config: Option, -) -> RedisResult<()> { +) -> RedisResult<()> +where + T: Sink, Error = RedisError>, + T: Stream>, + T: Unpin + Send + 'static, +{ if execute_connection_pipeline( - con, + codec, connection_setup_pipeline( connection_info, true, @@ -125,7 +149,7 @@ async fn setup_connection( == AuthResult::ShouldRetryWithoutUsername { execute_connection_pipeline( - con, + codec, connection_setup_pipeline( connection_info, false, @@ -140,7 +164,7 @@ async fn setup_connection( } mod connection; -pub use connection::*; +pub(crate) use connection::connect_simple; mod multiplexed_connection; pub use multiplexed_connection::*; #[cfg(feature = "connection-manager")] @@ -149,6 +173,21 @@ mod connection_manager; #[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))] pub use connection_manager::*; mod runtime; +#[cfg(all( + feature = "async-std-comp", + any(feature = "smol-comp", feature = "tokio-comp") +))] +pub use runtime::prefer_async_std; +#[cfg(all( + feature = "smol-comp", + any(feature = "async-std-comp", feature = "tokio-comp") +))] +pub use runtime::prefer_smol; +#[cfg(all( + feature = "tokio-comp", + any(feature = "async-std-comp", feature = "smol-comp") +))] +pub use runtime::prefer_tokio; pub(super) use runtime::*; macro_rules! check_resp3 { @@ -230,3 +269,48 @@ where self.as_ref().send(info) } } + +/// Default DNS resolver which uses the system's DNS resolver. +#[derive(Clone)] +pub(crate) struct DefaultAsyncDNSResolver; + +impl AsyncDNSResolver for DefaultAsyncDNSResolver { + fn resolve<'a, 'b: 'a>( + &'a self, + host: &'b str, + port: u16, + ) -> RedisFuture<'a, Box + Send + 'a>> { + Box::pin(get_socket_addrs(host, port).map(|vec| { + Ok(Box::new(vec?.into_iter()) as Box + Send>) + })) + } +} + +async fn get_socket_addrs(host: &str, port: u16) -> RedisResult> { + let socket_addrs: Vec<_> = match Runtime::locate() { + #[cfg(feature = "tokio-comp")] + Runtime::Tokio => ::tokio::net::lookup_host((host, port)) + .await + .map_err(RedisError::from) + .map(|iter| iter.collect()), + #[cfg(feature = "async-std-comp")] + Runtime::AsyncStd => Ok::<_, RedisError>( + ::async_std::net::ToSocketAddrs::to_socket_addrs(&(host, port)) + .await + .map(|iter| iter.collect())?, + ), + #[cfg(feature = "smol-comp")] + Runtime::Smol => ::smol::net::resolve((host, port)) + .await + .map_err(RedisError::from), + }?; + + if socket_addrs.is_empty() { + Err(RedisError::from(( + ErrorKind::InvalidClientConfig, + "No address found for host", + ))) + } else { + Ok(socket_addrs) + } +} diff --git a/redis/src/aio/monitor.rs b/redis/src/aio/monitor.rs new file mode 100644 index 000000000..31b81082e --- /dev/null +++ b/redis/src/aio/monitor.rs @@ -0,0 +1,109 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use futures_util::{ready, SinkExt, Stream, StreamExt}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::codec::Decoder; + +use crate::{ + cmd, parser::ValueCodec, types::closed_connection_error, FromRedisValue, RedisConnectionInfo, + RedisResult, Value, +}; + +use super::setup_connection; + +/// Represents a `Monitor` connection. 
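// ---- editor's sketch -----------------------------------------------------
// A user-supplied resolver for the `AsyncDNSResolver` trait, mirroring the
// `DefaultAsyncDNSResolver` impl above. This one maps every host to one fixed
// address; the generic parameters are reconstructed from context and may not
// match the crate verbatim.
use std::net::SocketAddr;

struct StaticResolver(SocketAddr);

impl redis::io::AsyncDNSResolver for StaticResolver {
    fn resolve<'a, 'b: 'a>(
        &'a self,
        _host: &'b str,
        _port: u16,
    ) -> redis::RedisFuture<'a, Box<dyn Iterator<Item = SocketAddr> + Send + 'a>> {
        let addr = self.0;
        Box::pin(async move {
            Ok(Box::new(std::iter::once(addr)) as Box<dyn Iterator<Item = SocketAddr> + Send>)
        })
    }
}
// ---------------------------------------------------------------------------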
+pub struct Monitor { + stream: Box> + Send + Sync + Unpin>, +} + +impl Monitor { + pub(crate) async fn new( + connection_info: &RedisConnectionInfo, + stream: C, + ) -> RedisResult + where + C: Unpin + AsyncRead + AsyncWrite + Send + Sync + 'static, + { + let mut codec = ValueCodec::default().framed(stream); + setup_connection( + &mut codec, + connection_info, + #[cfg(feature = "cache-aio")] + None, + ) + .await?; + codec.send(cmd("MONITOR").get_packed_command()).await?; + codec.next().await.ok_or_else(closed_connection_error)??; + let stream = Box::new(codec); + + Ok(Self { stream }) + } + + /// Deliver the MONITOR command to this [`Monitor`]ing wrapper. + #[deprecated(note = "A monitor now sends the MONITOR command automatically")] + pub async fn monitor(&mut self) -> RedisResult<()> { + Ok(()) + } + + /// Returns [`Stream`] of [`FromRedisValue`] values from this [`Monitor`]ing connection + pub fn on_message<'a, T: FromRedisValue + 'a>(&'a mut self) -> impl Stream + 'a { + MonitorStreamRef { + monitor: self, + _phantom: std::marker::PhantomData, + } + } + + /// Returns [`Stream`] of [`FromRedisValue`] values from this [`Monitor`]ing connection + pub fn into_on_message(self) -> impl Stream { + MonitorStream { + stream: self.stream, + _phantom: std::marker::PhantomData, + } + } +} + +struct MonitorStream { + stream: Box> + Send + Sync + Unpin>, + _phantom: std::marker::PhantomData, +} +impl Unpin for MonitorStream {} + +fn convert_value(value: RedisResult) -> Option +where + T: FromRedisValue, +{ + value + .ok() + .and_then(|value| T::from_owned_redis_value(value).ok()) +} + +impl Stream for MonitorStream +where + T: FromRedisValue, +{ + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Poll::Ready(ready!(self.stream.poll_next_unpin(cx)).and_then(convert_value)) + } +} + +struct MonitorStreamRef<'a, T> { + monitor: &'a mut Monitor, + _phantom: std::marker::PhantomData, +} +impl Unpin for MonitorStreamRef<'_, T> {} + +impl Stream for MonitorStreamRef<'_, T> +where + T: FromRedisValue, +{ + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Poll::Ready(ready!(self.monitor.stream.poll_next_unpin(cx)).and_then(convert_value)) + } +} diff --git a/redis/src/aio/multiplexed_connection.rs b/redis/src/aio/multiplexed_connection.rs index 41fcbc178..c90cd3336 100644 --- a/redis/src/aio/multiplexed_connection.rs +++ b/redis/src/aio/multiplexed_connection.rs @@ -3,7 +3,6 @@ use crate::aio::{check_resp3, setup_connection}; #[cfg(feature = "cache-aio")] use crate::caching::{CacheManager, CacheStatistics, PrepareCacheResult}; use crate::cmd::Cmd; -#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] use crate::parser::ValueCodec; use crate::types::{closed_connection_error, RedisError, RedisFuture, RedisResult, Value}; use crate::{ @@ -27,7 +26,6 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{self, Poll}; use std::time::Duration; -#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] use tokio_util::codec::Decoder; // Senders which the result of a single request are sent through @@ -357,11 +355,9 @@ impl Pipeline { #[cfg(feature = "cache-aio")] cache_manager: Option, ) -> (Self, impl Future) where - T: Sink, Error = RedisError> + Stream> + 'static, - T: Send + 'static, - T::Item: Send, - T::Error: Send, - T::Error: ::std::fmt::Debug, + T: Sink, Error = RedisError>, + T: Stream>, + T: Unpin + Send + 'static, { const BUFFER_SIZE: usize = 50; let (sender, mut receiver) = 
mpsc::channel(BUFFER_SIZE); @@ -386,29 +382,34 @@ impl Pipeline { // If `Some`, the value inside defines how the response should look like expectation: Option, timeout: Option, - ) -> Result> { + ) -> Result { let (sender, receiver) = oneshot::channel(); - self.sender - .send(PipelineMessage { - input, - expectation, - output: sender, - }) - .await - .map_err(|_| None)?; + let request = async { + self.sender + .send(PipelineMessage { + input, + expectation, + output: sender, + }) + .await + .map_err(|_| None)?; + + receiver.await + // The `sender` was dropped which likely means that the stream part + // failed for one reason or another + .map_err(|_| None) + .and_then(|res| res.map_err(Some)) + }; match timeout { - Some(timeout) => match Runtime::locate().timeout(timeout, receiver).await { + Some(timeout) => match Runtime::locate().timeout(timeout, request).await { Ok(res) => res, - Err(elapsed) => Ok(Err(elapsed.into())), + Err(elapsed) => Err(Some(elapsed.into())), }, - None => receiver.await, + None => request.await, } - // The `sender` was dropped which likely means that the stream part - // failed for one reason or another - .map_err(|_| None) - .and_then(|res| res.map_err(Some)) + .map_err(|err| err.unwrap_or_else(closed_connection_error)) } } @@ -491,10 +492,7 @@ impl MultiplexedConnection { where C: Unpin + AsyncRead + AsyncWrite + Send + 'static, { - #[cfg(all(not(feature = "tokio-comp"), not(feature = "async-std-comp")))] - compile_error!("tokio-comp or async-std-comp features required for aio feature"); - - let codec = ValueCodec::default().framed(stream); + let mut codec = ValueCodec::default().framed(stream); if config.push_sender.is_some() { check_resp3!( connection_info.protocol, @@ -502,25 +500,51 @@ impl MultiplexedConnection { ); } + #[cfg(feature = "cache-aio")] + let cache_config = config.cache.as_ref().map(|cache| match cache { + crate::client::Cache::Config(cache_config) => *cache_config, + #[cfg(feature = "connection-manager")] + crate::client::Cache::Manager(cache_manager) => cache_manager.cache_config, + }); #[cfg(feature = "cache-aio")] let cache_manager_opt = config - .cache_config - .map(|config| { + .cache + .map(|cache| { check_resp3!( connection_info.protocol, "Can only enable client side caching in a connection using RESP3" ); - Ok(CacheManager::new(config)) + match cache { + crate::client::Cache::Config(cache_config) => { + Ok(CacheManager::new(cache_config)) + } + #[cfg(feature = "connection-manager")] + crate::client::Cache::Manager(cache_manager) => Ok(cache_manager), + } }) .transpose()?; + setup_connection( + &mut codec, + connection_info, + #[cfg(feature = "cache-aio")] + cache_config, + ) + .await?; + if config.push_sender.is_some() { + check_resp3!( + connection_info.protocol, + "Can only pass push sender to a connection using RESP3" + ); + } + let (pipeline, driver) = Pipeline::new( codec, config.push_sender, #[cfg(feature = "cache-aio")] cache_manager_opt.clone(), ); - let mut con = MultiplexedConnection { + let con = MultiplexedConnection { pipeline, db: connection_info.db, response_timeout: config.response_timeout, @@ -529,29 +553,7 @@ impl MultiplexedConnection { #[cfg(feature = "cache-aio")] cache_manager: cache_manager_opt, }; - let driver = { - let auth = setup_connection( - connection_info, - &mut con, - #[cfg(feature = "cache-aio")] - config.cache_config, - ); - - futures_util::pin_mut!(auth); - match futures_util::future::select(auth, driver).await { - futures_util::future::Either::Left((result, driver)) => { - result?; - driver - } - 
futures_util::future::Either::Right(((), _)) => { - return Err(RedisError::from(( - crate::ErrorKind::IoError, - "Multiplexed connection driver unexpectedly terminated", - ))); - } - } - }; Ok((con, driver)) } @@ -588,8 +590,7 @@ impl MultiplexedConnection { }), self.response_timeout, ) - .await - .map_err(|err| err.unwrap_or_else(closed_connection_error))?; + .await?; let replies: Vec = crate::types::from_owned_redis_value(result)?; return cacheable_command.resolve(cache_manager, replies.into_iter()); } @@ -599,7 +600,6 @@ impl MultiplexedConnection { self.pipeline .send_recv(cmd.get_packed_command(), None, self.response_timeout) .await - .map_err(|err| err.unwrap_or_else(closed_connection_error)) } /// Sends multiple already encoded (packed) command into the TCP socket @@ -626,12 +626,11 @@ impl MultiplexedConnection { }), self.response_timeout, ) - .await - .map_err(|err| err.unwrap_or_else(closed_connection_error))?; + .await?; return cacheable_pipeline.resolve(cache_manager, result); } - let result = self + let value = self .pipeline .send_recv( cmd.get_packed_pipeline(), @@ -642,10 +641,7 @@ impl MultiplexedConnection { }), self.response_timeout, ) - .await - .map_err(|err| err.unwrap_or_else(closed_connection_error)); - - let value = result?; + .await?; match value { Value::Array(values) => Ok(values), _ => Ok(vec![value]), @@ -687,6 +683,7 @@ impl MultiplexedConnection { /// /// This method is only available when the connection is using RESP3 protocol, and will return an error otherwise. /// + /// ```rust,no_run /// # async fn func() -> redis::RedisResult<()> { /// let client = redis::Client::open("redis://127.0.0.1/?protocol=resp3").unwrap(); /// let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); @@ -694,7 +691,7 @@ impl MultiplexedConnection { /// let mut con = client.get_multiplexed_async_connection_with_config(&config).await?; /// con.subscribe(&["channel_1", "channel_2"]).await?; /// # Ok(()) } - /// # } + /// ``` pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> { check_resp3!(self.protocol); let mut cmd = cmd("SUBSCRIBE"); @@ -707,6 +704,7 @@ impl MultiplexedConnection { /// /// This method is only available when the connection is using RESP3 protocol, and will return an error otherwise. /// + /// ```rust,no_run /// # async fn func() -> redis::RedisResult<()> { /// let client = redis::Client::open("redis://127.0.0.1/?protocol=resp3").unwrap(); /// let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); @@ -715,7 +713,7 @@ impl MultiplexedConnection { /// con.subscribe(&["channel_1", "channel_2"]).await?; /// con.unsubscribe(&["channel_1", "channel_2"]).await?; /// # Ok(()) } - /// # } + /// ``` pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> { check_resp3!(self.protocol); let mut cmd = cmd("UNSUBSCRIBE"); @@ -731,6 +729,7 @@ impl MultiplexedConnection { /// /// This method is only available when the connection is using RESP3 protocol, and will return an error otherwise. 
/// + /// ```rust,no_run /// # async fn func() -> redis::RedisResult<()> { /// let client = redis::Client::open("redis://127.0.0.1/?protocol=resp3").unwrap(); /// let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); @@ -739,7 +738,7 @@ impl MultiplexedConnection { /// con.subscribe(&["channel_1", "channel_2"]).await?; /// con.unsubscribe(&["channel_1", "channel_2"]).await?; /// # Ok(()) } - /// # } + /// ``` pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> { check_resp3!(self.protocol); let mut cmd = cmd("PSUBSCRIBE"); diff --git a/redis/src/aio/pubsub.rs b/redis/src/aio/pubsub.rs index 17982257d..f5051882f 100644 --- a/redis/src/aio/pubsub.rs +++ b/redis/src/aio/pubsub.rs @@ -1,8 +1,4 @@ use crate::aio::Runtime; -use crate::connection::{ - check_connection_setup, connection_setup_pipeline, AuthResult, ConnectionSetupComponents, -}; -#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] use crate::parser::ValueCodec; use crate::types::{closed_connection_error, RedisError, RedisResult, Value}; use crate::{cmd, from_owned_redis_value, FromRedisValue, Msg, RedisConnectionInfo, ToRedisArgs}; @@ -13,7 +9,7 @@ use ::tokio::{ use futures_util::{ future::{Future, FutureExt}, ready, - sink::{Sink, SinkExt}, + sink::Sink, stream::{self, Stream, StreamExt}, }; use pin_project_lite::pin_project; @@ -22,10 +18,9 @@ use std::pin::Pin; use std::task::{self, Poll}; use tokio::sync::mpsc::unbounded_channel; use tokio::sync::mpsc::UnboundedSender; -#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] use tokio_util::codec::Decoder; -use super::SharedHandleContainer; +use super::{setup_connection, SharedHandleContainer}; // A signal that a un/subscribe request has completed. type RequestResultSender = oneshot::Sender>; @@ -250,10 +245,9 @@ impl PubSubSink { messages_sender: UnboundedSender, ) -> (Self, impl Future) where - T: Sink, Error = RedisError> + Stream> + Send + 'static, - T::Item: Send, - T::Error: Send, - T::Error: ::std::fmt::Debug, + T: Sink, Error = RedisError>, + T: Stream>, + T: Unpin + Send + 'static, { let (sender, mut receiver) = unbounded_channel(); let sink = PipelineSink::new(sink_stream, messages_sender); @@ -380,75 +374,6 @@ pub struct PubSub { stream: PubSubStream, } -async fn execute_connection_pipeline( - codec: &mut T, - (pipeline, instructions): (crate::Pipeline, ConnectionSetupComponents), -) -> RedisResult -where - T: Sink, Error = RedisError> + Stream> + 'static, - T: Send + 'static, - T::Item: Send, - T::Error: Send, - T::Error: ::std::fmt::Debug, - T: Unpin, -{ - let count = pipeline.len(); - if count == 0 { - return Ok(AuthResult::Succeeded); - } - codec.send(pipeline.get_packed_pipeline()).await?; - - let mut results = Vec::with_capacity(count); - for _ in 0..count { - let value = codec.next().await; - match value { - Some(Ok(val)) => results.push(val), - _ => return Err(closed_connection_error()), - } - } - - check_connection_setup(results, instructions) -} - -async fn setup_connection( - codec: &mut T, - connection_info: &RedisConnectionInfo, -) -> RedisResult<()> -where - T: Sink, Error = RedisError> + Stream> + 'static, - T: Send + 'static, - T::Item: Send, - T::Error: Send, - T::Error: ::std::fmt::Debug, - T: Unpin, -{ - if execute_connection_pipeline( - codec, - connection_setup_pipeline( - connection_info, - true, - #[cfg(feature = "cache-aio")] - None, - ), - ) - .await? 
- == AuthResult::ShouldRetryWithoutUsername - { - execute_connection_pipeline( - codec, - connection_setup_pipeline( - connection_info, - false, - #[cfg(feature = "cache-aio")] - None, - ), - ) - .await?; - } - - Ok(()) -} - impl PubSub { /// Constructs a new `MultiplexedConnection` out of a `AsyncRead + AsyncWrite` object /// and a `ConnectionInfo` @@ -456,11 +381,14 @@ impl PubSub { where C: Unpin + AsyncRead + AsyncWrite + Send + 'static, { - #[cfg(all(not(feature = "tokio-comp"), not(feature = "async-std-comp")))] - compile_error!("tokio-comp or async-std-comp features required for aio feature"); - let mut codec = ValueCodec::default().framed(stream); - setup_connection(&mut codec, connection_info).await?; + setup_connection( + &mut codec, + connection_info, + #[cfg(feature = "cache-aio")] + None, + ) + .await?; let (sender, receiver) = unbounded_channel(); let (sink, driver) = PubSubSink::new(codec, sender); let handle = Runtime::locate().spawn(driver); diff --git a/redis/src/aio/runtime.rs b/redis/src/aio/runtime.rs index 4913d5f40..d6b1eb0b7 100644 --- a/redis/src/aio/runtime.rs +++ b/redis/src/aio/runtime.rs @@ -2,19 +2,41 @@ use std::{io, sync::Arc, time::Duration}; use futures_util::Future; +#[cfg(any( + all( + feature = "tokio-comp", + any(feature = "async-std-comp", feature = "smol-comp") + ), + all( + feature = "smol-comp", + any(feature = "async-std-comp", feature = "tokio-comp") + ), + all( + feature = "async-std-comp", + any(feature = "tokio-comp", feature = "smol-comp") + ) +))] +use std::sync::OnceLock; + #[cfg(feature = "async-std-comp")] use super::async_std as crate_async_std; +#[cfg(feature = "smol-comp")] +use super::smol as crate_smol; #[cfg(feature = "tokio-comp")] use super::tokio as crate_tokio; use super::RedisRuntime; use crate::types::RedisError; +#[cfg(feature = "smol-comp")] +use smol_timeout::TimeoutExt; -#[derive(Clone, Debug)] +#[derive(Clone, Copy, Debug)] pub(crate) enum Runtime { #[cfg(feature = "tokio-comp")] Tokio, #[cfg(feature = "async-std-comp")] AsyncStd, + #[cfg(feature = "smol-comp")] + Smol, } pub(crate) enum TaskHandle { @@ -22,6 +44,8 @@ pub(crate) enum TaskHandle { Tokio(tokio::task::JoinHandle<()>), #[cfg(feature = "async-std-comp")] AsyncStd(async_std::task::JoinHandle<()>), + #[cfg(feature = "smol-comp")] + Smol(smol::Task<()>), } pub(crate) struct HandleContainer(Option); @@ -41,8 +65,11 @@ impl Drop for HandleContainer { #[cfg(feature = "async-std-comp")] Some(TaskHandle::AsyncStd(handle)) => { // schedule for cancellation without waiting for result. + // TODO - can we cancel the task without awaiting its completion? 
Runtime::locate().spawn(async move { handle.cancel().await.unwrap_or_default() }); } + #[cfg(feature = "smol-comp")] + Some(TaskHandle::Smol(task)) => drop(task), } } } @@ -58,30 +85,159 @@ impl SharedHandleContainer { } } +#[cfg(any( + all( + feature = "tokio-comp", + any(feature = "async-std-comp", feature = "smol-comp") + ), + all( + feature = "smol-comp", + any(feature = "async-std-comp", feature = "tokio-comp") + ), + all( + feature = "async-std-comp", + any(feature = "tokio-comp", feature = "smol-comp") + ) +))] +static CHOSEN_RUNTIME: OnceLock = OnceLock::new(); + +#[cfg(any( + all( + feature = "tokio-comp", + any(feature = "async-std-comp", feature = "smol-comp") + ), + all( + feature = "smol-comp", + any(feature = "async-std-comp", feature = "tokio-comp") + ), + all( + feature = "async-std-comp", + any(feature = "tokio-comp", feature = "smol-comp") + ) +))] +fn set_runtime(runtime: Runtime) -> Result<(), RedisError> { + const PREFER_RUNTIME_ERROR: &str = + "Another runtime preference was already set. Please call this function before any other runtime preference is set."; + + CHOSEN_RUNTIME + .set(runtime) + .map_err(|_| RedisError::from((crate::ErrorKind::ClientError, PREFER_RUNTIME_ERROR))) +} + +/// Mark Smol as the preferred runtime. +/// +/// If the function returns `Err`, another runtime preference was already set, and won't be changed. +/// Call this function if the application doesn't use multiple runtimes, +/// but the crate is compiled with multiple runtimes enabled, which is a bad pattern that should be avoided. +#[cfg(all( + feature = "smol-comp", + any(feature = "async-std-comp", feature = "tokio-comp") +))] +pub fn prefer_smol() -> Result<(), RedisError> { + set_runtime(Runtime::Smol) +} + +/// Mark async-std compliant runtimes, such as smol, as the preferred runtime. +/// +/// If the function returns `Err`, another runtime preference was already set, and won't be changed. +/// Call this function if the application doesn't use multiple runtimes, +/// but the crate is compiled with multiple runtimes enabled, which is a bad pattern that should be avoided. +#[cfg(all( + feature = "async-std-comp", + any(feature = "tokio-comp", feature = "smol-comp") +))] +pub fn prefer_async_std() -> Result<(), RedisError> { + set_runtime(Runtime::AsyncStd) +} + +/// Mark Tokio as the preferred runtime. +/// +/// If the function returns `Err`, another runtime preference was already set, and won't be changed. +/// Call this function if the application doesn't use multiple runtimes, +/// but the crate is compiled with multiple runtimes enabled, which is a bad pattern that should be avoided. 
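// ---- editor's sketch -----------------------------------------------------
// Usage of the runtime-preference API introduced here: when the crate is
// compiled with more than one runtime feature, pin the runtime the
// application actually uses, once, before opening any connection.
fn pin_runtime() {
    // Shown for tokio; `prefer_smol` and `prefer_async_std` are analogous.
    // Returns `Err` if a preference was already recorded.
    if let Err(err) = redis::aio::prefer_tokio() {
        eprintln!("runtime preference already set: {err}");
    }
}
// ---------------------------------------------------------------------------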
+#[cfg(all( + feature = "tokio-comp", + any(feature = "async-std-comp", feature = "smol-comp") +))] +pub fn prefer_tokio() -> Result<(), RedisError> { + set_runtime(Runtime::Tokio) +} + impl Runtime { pub(crate) fn locate() -> Self { - #[cfg(all(feature = "tokio-comp", not(feature = "async-std-comp")))] + #[cfg(any( + all( + feature = "tokio-comp", + any(feature = "async-std-comp", feature = "smol-comp") + ), + all( + feature = "smol-comp", + any(feature = "async-std-comp", feature = "tokio-comp") + ), + all( + feature = "async-std-comp", + any(feature = "tokio-comp", feature = "smol-comp") + ) + ))] + if let Some(runtime) = CHOSEN_RUNTIME.get() { + return *runtime; + } + + #[cfg(all( + feature = "tokio-comp", + not(feature = "async-std-comp"), + not(feature = "smol-comp") + ))] { Runtime::Tokio } - #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + #[cfg(all( + not(feature = "tokio-comp"), + not(feature = "smol-comp"), + feature = "async-std-comp" + ))] { Runtime::AsyncStd } - #[cfg(all(feature = "tokio-comp", feature = "async-std-comp"))] + #[cfg(all( + not(feature = "tokio-comp"), + feature = "smol-comp", + not(feature = "async-std-comp") + ))] { + Runtime::Smol + } + + cfg_if::cfg_if! { + if #[cfg(all(feature = "tokio-comp", feature = "async-std-comp"))] { if ::tokio::runtime::Handle::try_current().is_ok() { Runtime::Tokio } else { Runtime::AsyncStd } + } else if #[cfg(all(feature = "tokio-comp", feature = "smol-comp"))] { + if ::tokio::runtime::Handle::try_current().is_ok() { + Runtime::Tokio + } else { + Runtime::Smol + } + } else if #[cfg(all(feature = "smol-comp", feature = "async-std-comp"))] + { + Runtime::AsyncStd + } } - #[cfg(all(not(feature = "tokio-comp"), not(feature = "async-std-comp")))] + #[cfg(all( + not(feature = "tokio-comp"), + not(feature = "async-std-comp"), + not(feature = "smol-comp") + ))] { - compile_error!("tokio-comp or async-std-comp features required for aio feature") + compile_error!( + "tokio-comp, async-std-comp, or smol-comp features required for aio feature" + ) } } @@ -92,6 +248,8 @@ impl Runtime { Runtime::Tokio => crate_tokio::Tokio::spawn(f), #[cfg(feature = "async-std-comp")] Runtime::AsyncStd => crate_async_std::AsyncStd::spawn(f), + #[cfg(feature = "smol-comp")] + Runtime::Smol => crate_smol::Smol::spawn(f), } } @@ -109,6 +267,8 @@ impl Runtime { Runtime::AsyncStd => async_std::future::timeout(duration, future) .await .map_err(|_| Elapsed(())), + #[cfg(feature = "smol-comp")] + Runtime::Smol => future.timeout(duration).await.ok_or(Elapsed(())), } } @@ -123,6 +283,10 @@ impl Runtime { Runtime::AsyncStd => { async_std::task::sleep(duration).await; } + #[cfg(feature = "smol-comp")] + Runtime::Smol => { + smol::Timer::after(duration).await; + } } } diff --git a/redis/src/aio/smol.rs b/redis/src/aio/smol.rs new file mode 100644 index 000000000..b1742213f --- /dev/null +++ b/redis/src/aio/smol.rs @@ -0,0 +1,247 @@ +#[cfg(unix)] +use std::path::Path; +use std::sync::Arc; +use std::{ + future::Future, + io, + net::SocketAddr, + pin::Pin, + task::{self, Poll}, +}; + +use crate::aio::{AsyncStream, RedisRuntime}; +use crate::types::RedisResult; + +#[cfg(all(feature = "smol-native-tls-comp", not(feature = "smol-rustls-comp")))] +use async_native_tls::{TlsConnector, TlsStream}; + +#[cfg(feature = "smol-rustls-comp")] +use crate::connection::create_rustls_config; +#[cfg(feature = "smol-rustls-comp")] +use futures_rustls::{client::TlsStream, TlsConnector}; + +use super::TaskHandle; +use futures_util::ready; +#[cfg(unix)] +use 
smol::net::unix::UnixStream;
+use smol::net::TcpStream;
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+
+#[inline(always)]
+async fn connect_tcp(
+    addr: &SocketAddr,
+    tcp_settings: &crate::io::tcp::TcpSettings,
+) -> io::Result<TcpStream> {
+    let socket = TcpStream::connect(addr).await?;
+    let socket_inner: Arc<smol::Async<std::net::TcpStream>> = socket.into();
+    let async_socket_inner = Arc::into_inner(socket_inner).unwrap();
+    let std_socket = async_socket_inner.into_inner()?;
+    let std_socket = crate::io::tcp::stream_with_settings(std_socket, tcp_settings)?;
+
+    std_socket.try_into()
+}
+
+#[cfg(any(feature = "smol-rustls-comp", feature = "smol-native-tls-comp"))]
+use crate::connection::TlsConnParams;
+
+pin_project_lite::pin_project! {
+    /// Wraps the smol `AsyncRead/AsyncWrite` in order to implement the required tokio traits
+    /// for it
+    pub struct SmolWrapped<T> { #[pin] inner: T }
+}
+
+impl<T> SmolWrapped<T> {
+    pub(super) fn new(inner: T) -> Self {
+        Self { inner }
+    }
+}
+
+impl<T> AsyncWrite for SmolWrapped<T>
+where
+    T: smol::io::AsyncWrite,
+{
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut core::task::Context,
+        buf: &[u8],
+    ) -> std::task::Poll<Result<usize, io::Error>> {
+        smol::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
+    }
+
+    fn poll_flush(
+        self: Pin<&mut Self>,
+        cx: &mut core::task::Context,
+    ) -> std::task::Poll<Result<(), io::Error>> {
+        smol::io::AsyncWrite::poll_flush(self.project().inner, cx)
+    }
+
+    fn poll_shutdown(
+        self: Pin<&mut Self>,
+        cx: &mut core::task::Context,
+    ) -> std::task::Poll<Result<(), io::Error>> {
+        smol::io::AsyncWrite::poll_close(self.project().inner, cx)
+    }
+}
+
+impl<T> AsyncRead for SmolWrapped<T>
+where
+    T: smol::prelude::AsyncRead,
+{
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut core::task::Context,
+        buf: &mut ReadBuf<'_>,
+    ) -> std::task::Poll<io::Result<()>> {
+        let n = ready!(smol::prelude::AsyncRead::poll_read(
+            self.project().inner,
+            cx,
+            buf.initialize_unfilled()
+        ))?;
+        buf.advance(n);
+        std::task::Poll::Ready(Ok(()))
+    }
+}
+
+/// Represents a Smol connectable
+pub enum Smol {
+    /// Represents a TCP connection.
+    Tcp(SmolWrapped<TcpStream>),
+    /// Represents a TLS encrypted TCP connection.
+    #[cfg(any(feature = "smol-native-tls-comp", feature = "smol-rustls-comp"))]
+    TcpTls(SmolWrapped<Box<TlsStream<TcpStream>>>),
+    /// Represents a Unix connection.
+ #[cfg(unix)] + Unix(SmolWrapped), +} + +impl AsyncWrite for Smol { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + buf: &[u8], + ) -> Poll> { + match &mut *self { + Smol::Tcp(r) => Pin::new(r).poll_write(cx, buf), + #[cfg(any(feature = "smol-native-tls-comp", feature = "smol-rustls-comp"))] + Smol::TcpTls(r) => Pin::new(r).poll_write(cx, buf), + #[cfg(unix)] + Smol::Unix(r) => Pin::new(r).poll_write(cx, buf), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll> { + match &mut *self { + Smol::Tcp(r) => Pin::new(r).poll_flush(cx), + #[cfg(any(feature = "smol-native-tls-comp", feature = "smol-rustls-comp"))] + Smol::TcpTls(r) => Pin::new(r).poll_flush(cx), + #[cfg(unix)] + Smol::Unix(r) => Pin::new(r).poll_flush(cx), + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll> { + match &mut *self { + Smol::Tcp(r) => Pin::new(r).poll_shutdown(cx), + #[cfg(any(feature = "smol-native-tls-comp", feature = "smol-rustls-comp"))] + Smol::TcpTls(r) => Pin::new(r).poll_shutdown(cx), + #[cfg(unix)] + Smol::Unix(r) => Pin::new(r).poll_shutdown(cx), + } + } +} + +impl AsyncRead for Smol { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match &mut *self { + Smol::Tcp(r) => Pin::new(r).poll_read(cx, buf), + #[cfg(any(feature = "smol-native-tls-comp", feature = "smol-rustls-comp"))] + Smol::TcpTls(r) => Pin::new(r).poll_read(cx, buf), + #[cfg(unix)] + Smol::Unix(r) => Pin::new(r).poll_read(cx, buf), + } + } +} + +impl RedisRuntime for Smol { + async fn connect_tcp( + socket_addr: SocketAddr, + tcp_settings: &crate::io::tcp::TcpSettings, + ) -> RedisResult { + Ok(connect_tcp(&socket_addr, tcp_settings) + .await + .map(|con| Self::Tcp(SmolWrapped::new(con)))?) + } + + #[cfg(all(feature = "smol-native-tls-comp", not(feature = "smol-rustls-comp")))] + async fn connect_tcp_tls( + hostname: &str, + socket_addr: SocketAddr, + insecure: bool, + tls_params: &Option, + tcp_settings: &crate::io::tcp::TcpSettings, + ) -> RedisResult { + let tcp_stream = connect_tcp(&socket_addr, tcp_settings).await?; + let tls_connector = if insecure { + TlsConnector::new() + .danger_accept_invalid_certs(true) + .danger_accept_invalid_hostnames(true) + .use_sni(false) + } else if let Some(params) = tls_params { + TlsConnector::new() + .danger_accept_invalid_hostnames(params.danger_accept_invalid_hostnames) + } else { + TlsConnector::new() + }; + Ok(tls_connector + .connect(hostname, tcp_stream) + .await + .map(|con| Self::TcpTls(SmolWrapped::new(Box::new(con))))?) + } + + #[cfg(feature = "smol-rustls-comp")] + async fn connect_tcp_tls( + hostname: &str, + socket_addr: SocketAddr, + insecure: bool, + tls_params: &Option, + tcp_settings: &crate::io::tcp::TcpSettings, + ) -> RedisResult { + let tcp_stream = connect_tcp(&socket_addr, tcp_settings).await?; + + let config = create_rustls_config(insecure, tls_params.clone())?; + let tls_connector = TlsConnector::from(Arc::new(config)); + + Ok(tls_connector + .connect( + rustls::pki_types::ServerName::try_from(hostname)?.to_owned(), + tcp_stream, + ) + .await + .map(|con| Self::TcpTls(SmolWrapped::new(Box::new(con))))?) + } + + #[cfg(unix)] + async fn connect_unix(path: &Path) -> RedisResult { + Ok(UnixStream::connect(path) + .await + .map(|con| Self::Unix(SmolWrapped::new(con)))?) 
+ } + + fn spawn(f: impl Future + Send + 'static) -> TaskHandle { + TaskHandle::Smol(smol::spawn(f)) + } + + fn boxed(self) -> Pin> { + match self { + Smol::Tcp(x) => Box::pin(x), + #[cfg(any(feature = "smol-native-tls-comp", feature = "smol-rustls-comp"))] + Smol::TcpTls(x) => Box::pin(x), + #[cfg(unix)] + Smol::Unix(x) => Box::pin(x), + } + } +} diff --git a/redis/src/aio/tokio.rs b/redis/src/aio/tokio.rs index 8f88920ed..a7e51a198 100644 --- a/redis/src/aio/tokio.rs +++ b/redis/src/aio/tokio.rs @@ -12,23 +12,20 @@ use tokio::{ net::TcpStream as TcpStreamTokio, }; -#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] +#[cfg(all(feature = "tokio-native-tls-comp", not(feature = "tokio-rustls-comp")))] use native_tls::TlsConnector; -#[cfg(feature = "tls-rustls")] +#[cfg(feature = "tokio-rustls-comp")] use crate::connection::create_rustls_config; -#[cfg(feature = "tls-rustls")] +#[cfg(feature = "tokio-rustls-comp")] use std::sync::Arc; -#[cfg(feature = "tls-rustls")] +#[cfg(feature = "tokio-rustls-comp")] use tokio_rustls::{client::TlsStream, TlsConnector}; #[cfg(all(feature = "tokio-native-tls-comp", not(feature = "tokio-rustls-comp")))] use tokio_native_tls::TlsStream; -#[cfg(feature = "tokio-rustls-comp")] -use crate::tls::TlsConnParams; - -#[cfg(all(feature = "tokio-native-tls-comp", not(feature = "tls-rustls")))] +#[cfg(any(feature = "tokio-rustls-comp", feature = "tokio-native-tls-comp"))] use crate::connection::TlsConnParams; #[cfg(unix)] @@ -119,12 +116,12 @@ impl RedisRuntime for Tokio { .map(Tokio::Tcp)?) } - #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + #[cfg(all(feature = "tokio-native-tls-comp", not(feature = "tokio-rustls-comp")))] async fn connect_tcp_tls( hostname: &str, socket_addr: SocketAddr, insecure: bool, - _: &Option, + params: &Option, tcp_settings: &crate::io::tcp::TcpSettings, ) -> RedisResult { let tls_connector: tokio_native_tls::TlsConnector = if insecure { @@ -133,6 +130,10 @@ impl RedisRuntime for Tokio { .danger_accept_invalid_hostnames(true) .use_sni(false) .build()? + } else if let Some(params) = params { + TlsConnector::builder() + .danger_accept_invalid_hostnames(params.danger_accept_invalid_hostnames) + .build()? } else { TlsConnector::new()? } @@ -143,7 +144,7 @@ impl RedisRuntime for Tokio { .map(|con| Tokio::TcpTls(Box::new(con)))?) } - #[cfg(feature = "tls-rustls")] + #[cfg(feature = "tokio-rustls-comp")] async fn connect_tcp_tls( hostname: &str, socket_addr: SocketAddr, @@ -168,16 +169,10 @@ impl RedisRuntime for Tokio { Ok(UnixStreamTokio::connect(path).await.map(Tokio::Unix)?) } - #[cfg(feature = "tokio-comp")] fn spawn(f: impl Future + Send + 'static) -> TaskHandle { TaskHandle::Tokio(tokio::spawn(f)) } - #[cfg(not(feature = "tokio-comp"))] - fn spawn(_: impl Future + Send + 'static) -> TokioTaskHandle { - unreachable!() - } - fn boxed(self) -> Pin> { match self { Tokio::Tcp(x) => Box::pin(x), diff --git a/redis/src/bb8.rs b/redis/src/bb8.rs new file mode 100644 index 000000000..8a1523e49 --- /dev/null +++ b/redis/src/bb8.rs @@ -0,0 +1,47 @@ +use crate::aio::MultiplexedConnection; +use crate::{Client, Cmd, ErrorKind, RedisError}; + +#[cfg(feature = "cluster-async")] +use crate::{cluster::ClusterClient, cluster_async::ClusterConnection}; + +macro_rules! 
impl_bb8_manage_connection {
+    ($client:ty, $connection:ty, $get_conn:expr) => {
+        impl bb8::ManageConnection for $client {
+            type Connection = $connection;
+            type Error = RedisError;
+
+            async fn connect(&self) -> Result<Self::Connection, Self::Error> {
+                $get_conn(self).await
+            }
+
+            async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
+                let pong: String = Cmd::ping().query_async(conn).await?;
+                match pong.as_str() {
+                    "PONG" => Ok(()),
+                    _ => Err((ErrorKind::ResponseError, "ping request").into()),
+                }
+            }
+
+            fn has_broken(&self, _: &mut Self::Connection) -> bool {
+                false
+            }
+        }
+    };
+}
+
+impl_bb8_manage_connection!(
+    Client,
+    MultiplexedConnection,
+    Client::get_multiplexed_async_connection
+);
+
+#[cfg(feature = "cluster-async")]
+impl_bb8_manage_connection!(
+    ClusterClient,
+    ClusterConnection,
+    ClusterClient::get_async_connection
+);
+
+// TODO: support bb8 for the sentinel client, which requires
+// [`crate::sentinel::LockedSentinelClient`] to implement an async
+// `get_multiplexed_async_connection` method.
diff --git a/redis/src/caching/cache_manager.rs b/redis/src/caching/cache_manager.rs
index 8f5819a3c..498987a08 100644
--- a/redis/src/caching/cache_manager.rs
+++ b/redis/src/caching/cache_manager.rs
@@ -20,18 +20,34 @@ pub(crate) enum PrepareCacheResult<'a> {
 pub(crate) struct CacheManager {
     lru: Arc<ShardedLRU>,
     pub(crate) cache_config: CacheConfig,
+    epoch: usize,
 }
 
 impl CacheManager {
     pub(crate) fn new(cache_config: CacheConfig) -> Self {
+        let lru = Arc::new(ShardedLRU::new(cache_config.size));
+        let epoch = lru.increase_epoch();
         CacheManager {
-            lru: Arc::new(ShardedLRU::new(cache_config.size)),
+            lru,
             cache_config,
+            epoch,
+        }
+    }
+
+    // Clone the CacheManager and increase the epoch on the shared LRU;
+    // entries created under the previous CacheManager's epoch will
+    // eventually be evicted.
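// ---- editor's sketch -----------------------------------------------------
// Usage of the bb8 integration a few hunks above: the macro implements bb8's
// `ManageConnection` for `Client`, so a pool can be built from it directly
// (the sizing knobs on `bb8::Pool::builder()` are bb8's own).
async fn pooled_ping(client: redis::Client) -> redis::RedisResult<()> {
    let pool = bb8::Pool::builder().build(client).await?;
    let mut conn = pool.get().await.expect("checkout failed");
    let _pong: String = redis::cmd("PING").query_async(&mut *conn).await?;
    Ok(())
}
// ---------------------------------------------------------------------------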
+ #[cfg(feature = "connection-manager")] + pub(crate) fn clone_and_increase_epoch(&self) -> CacheManager { + CacheManager { + lru: self.lru.clone(), + cache_config: self.cache_config, + epoch: self.lru.increase_epoch(), } } pub(crate) fn get<'a>(&self, redis_key: &'a [u8], redis_cmd: &'a [u8]) -> Option { - self.lru.get(redis_key, redis_cmd) + self.lru.get(redis_key, redis_cmd, self.epoch) } pub(crate) fn insert( @@ -51,7 +67,8 @@ impl CacheManager { } _ => client_side_expire_time, }; - self.lru.insert(redis_key, cmd_key, value, expire_time); + self.lru + .insert(redis_key, cmd_key, value, expire_time, self.epoch); } pub(crate) fn statistics(&self) -> CacheStatistics { @@ -358,4 +375,60 @@ mod tests { "Key must be alive, client value must be picked" ); } + + #[test] + #[cfg(feature = "connection-manager")] + fn test_epoch_on_shared_cache_managers() { + let redis_key = b"test_redis_key".as_slice(); + let redis_key_2 = b"test_redis_key_2".as_slice(); + let redis_key_3 = b"test_redis_key_3".as_slice(); + let cmd_key = b"test_cmd_key".as_slice(); + + let shared_cache_manager = CacheManager::new(CacheConfig::new()); + + let cache_manager_1 = shared_cache_manager.clone_and_increase_epoch(); + let cache_manager_2 = shared_cache_manager.clone_and_increase_epoch(); + let cache_manager_3 = shared_cache_manager.clone_and_increase_epoch(); + + let secs_10 = Instant::now().add(Duration::from_secs(10)); + + let do_inserts = |cm1: &CacheManager, cm2: &CacheManager, cm3: &CacheManager| { + cm1.insert(redis_key, cmd_key, Value::Int(1), secs_10, &Value::Int(5)); + cm2.insert(redis_key_2, cmd_key, Value::Int(2), secs_10, &Value::Int(5)); + cm3.insert(redis_key_3, cmd_key, Value::Int(3), secs_10, &Value::Int(5)); + }; + + let do_hit_gets = |cm1: &CacheManager, cm2: &CacheManager, cm3: &CacheManager| { + assert_eq!(cm1.get(redis_key, cmd_key), Some(Value::Int(1))); + assert_eq!(cm2.get(redis_key_2, cmd_key), Some(Value::Int(2))); + assert_eq!(cm3.get(redis_key_3, cmd_key), Some(Value::Int(3))); + }; + + let do_miss_gets = |cm1: &CacheManager, cm2: &CacheManager, cm3: &CacheManager| { + assert_eq!(cm1.get(redis_key, cmd_key), None); + assert_eq!(cm2.get(redis_key_2, cmd_key), None); + assert_eq!(cm3.get(redis_key_3, cmd_key), None); + }; + + do_inserts(&cache_manager_1, &cache_manager_2, &cache_manager_3); + do_hit_gets(&cache_manager_1, &cache_manager_2, &cache_manager_3); + // Different CacheManagers has different epochs so all must return None + do_miss_gets(&cache_manager_2, &cache_manager_3, &cache_manager_1); + + // Check when only one CacheManager has increased the epoch + do_inserts(&cache_manager_1, &cache_manager_2, &cache_manager_3); + do_hit_gets(&cache_manager_1, &cache_manager_2, &cache_manager_3); + + let cache_manager_1 = cache_manager_1.clone_and_increase_epoch(); + assert_eq!(cache_manager_1.get(redis_key, cmd_key), None); + + assert_eq!( + cache_manager_2.get(redis_key_2, cmd_key), + Some(Value::Int(2)) + ); + assert_eq!( + cache_manager_3.get(redis_key_3, cmd_key), + Some(Value::Int(3)) + ); + } } diff --git a/redis/src/caching/sharded_lru.rs b/redis/src/caching/sharded_lru.rs index 050b488df..c0e83d5ed 100644 --- a/redis/src/caching/sharded_lru.rs +++ b/redis/src/caching/sharded_lru.rs @@ -4,6 +4,7 @@ use lru::LruCache; use std::collections::hash_map::DefaultHasher; use std::hash::Hasher; use std::num::NonZeroUsize; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Instant; @@ -18,6 +19,7 @@ pub(crate) struct CacheCmdEntry { /// CacheItem keeps information 
about a key's expiry time and cached response for each key, command pair.
 pub(crate) struct CacheItem {
     expire_time: Instant,
+    epoch: usize,
     value_list: Vec<CacheCmdEntry>,
 }
 
@@ -26,6 +28,7 @@ type LRUCacheShard = LruCache<Vec<u8>, CacheItem>;
 pub(crate) struct ShardedLRU {
     shards: Vec<std::sync::Mutex<LRUCacheShard>>,
     pub(crate) statistics: Arc<Statistics>,
+    last_epoch: AtomicUsize,
 }
 
 impl ShardedLRU {
@@ -48,7 +51,11 @@ impl ShardedLRU {
             shards.push(std::sync::Mutex::new(shard));
         }
         let statistics = Arc::new(Statistics::default());
-        ShardedLRU { shards, statistics }
+        ShardedLRU {
+            shards,
+            statistics,
+            last_epoch: AtomicUsize::new(0),
+        }
     }
 
     /// get_shard will get the MutexGuard for the shard determined by `key`; if the lock is poisoned, it'll be recovered.
@@ -59,11 +66,20 @@ impl ShardedLRU {
         lock.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
     }
 
-    pub(crate) fn get<'a>(&self, redis_key: &'a [u8], redis_cmd: &'a [u8]) -> Option<Value> {
+    pub(crate) fn get<'a>(
+        &self,
+        redis_key: &'a [u8],
+        redis_cmd: &'a [u8],
+        epoch: usize,
+    ) -> Option<Value> {
         let mut lru_cache = self.get_shard(redis_key);
         if let Some(cache_item) = lru_cache.get_mut(redis_key) {
-            if Instant::now() > cache_item.expire_time {
-                // Key is expired.
+            // If either of the following holds, the cache item is invalid and can't be trusted:
+            // the client's epoch differs, meaning the item was created by another redis connection,
+            // or the key's expire time has passed, so the value could be stale.
+            let cache_item_is_invalid =
+                cache_item.epoch != epoch || Instant::now() > cache_item.expire_time;
+            if cache_item_is_invalid {
                 self.statistics
                     .increase_invalidate(cache_item.value_list.len());
                 self.statistics.increase_miss(1);
@@ -88,33 +104,37 @@
         cmd_key: &[u8],
         value: Value,
         expire_time: Instant,
+        epoch: usize,
     ) {
         let mut lru_cache = self.get_shard(redis_key);
         if let Some(ch) = lru_cache.peek_mut(redis_key) {
-            for entry in &mut ch.value_list {
-                if entry.cmd == cmd_key {
-                    entry.value = value;
-                    ch.expire_time = expire_time;
-                    return;
+            if ch.epoch == epoch {
+                for entry in &mut ch.value_list {
+                    if entry.cmd == cmd_key {
+                        entry.value = value;
+                        ch.expire_time = expire_time;
+                        return;
+                    }
                 }
+                ch.value_list.push(CacheCmdEntry {
+                    cmd: cmd_key.to_vec(),
+                    value,
+                });
+                ch.expire_time = expire_time;
+                return;
             }
-            ch.value_list.push(CacheCmdEntry {
-                cmd: cmd_key.to_vec(),
-                value,
-            });
-            ch.expire_time = expire_time;
-        } else {
-            let _ = lru_cache.push(
-                redis_key.to_vec(),
-                CacheItem {
-                    expire_time,
-                    value_list: vec![CacheCmdEntry {
-                        cmd: cmd_key.to_vec(),
-                        value,
-                    }],
-                },
-            );
         }
+        let _ = lru_cache.push(
+            redis_key.to_vec(),
+            CacheItem {
+                expire_time,
+                value_list: vec![CacheCmdEntry {
+                    cmd: cmd_key.to_vec(),
+                    value,
+                }],
+                epoch,
+            },
+        );
     }
 
     pub(crate) fn invalidate(&self, cache_key: &Vec<u8>) {
@@ -123,6 +143,10 @@
                 .increase_invalidate(cache_holder.value_list.len());
         }
     }
+
+    pub(crate) fn increase_epoch(&self) -> usize {
+        self.last_epoch.fetch_add(1, Ordering::Relaxed)
+    }
 }
 
 #[cfg(test)]
 mod tests {
@@ -142,13 +166,14 @@
             CMD_KEY,
             Value::Boolean(true),
             Instant::now().add(Duration::from_secs(10)),
+            0,
         );
         assert_eq!(
-            sharded_lru.get(REDIS_KEY, CMD_KEY),
+            sharded_lru.get(REDIS_KEY, CMD_KEY, 0),
             Some(Value::Boolean(true))
         );
         assert_eq!(
-            sharded_lru.get(REDIS_KEY, CMD_KEY_2),
+            sharded_lru.get(REDIS_KEY, CMD_KEY_2, 0),
             None,
             "Using different cmd key must result in cache miss"
         );
@@ -158,15 +183,16 @@
             CMD_KEY,
             Value::Boolean(false),
             Instant::now().add(Duration::from_millis(5)),
+            0,
         );
         assert_eq!(
-            sharded_lru.get(REDIS_KEY,
CMD_KEY), + sharded_lru.get(REDIS_KEY, CMD_KEY, 0), Some(Value::Boolean(false)), "Old value must be overwritten" ); std::thread::sleep(Duration::from_millis(6)); assert_eq!( - sharded_lru.get(REDIS_KEY, CMD_KEY), + sharded_lru.get(REDIS_KEY, CMD_KEY, 0), None, "Cache must be expired" ); @@ -181,6 +207,7 @@ mod tests { CMD_KEY, Value::Int(1), Instant::now().add(Duration::from_secs(10)), + 0, ); // Second insert must override expire of the redis key. sharded_lru.insert( @@ -188,21 +215,96 @@ mod tests { CMD_KEY_2, Value::Int(2), Instant::now().add(Duration::from_millis(5)), + 0, ); - assert_eq!(sharded_lru.get(REDIS_KEY, CMD_KEY), Some(Value::Int(1))); - assert_eq!(sharded_lru.get(REDIS_KEY, CMD_KEY_2), Some(Value::Int(2))); + assert_eq!(sharded_lru.get(REDIS_KEY, CMD_KEY, 0), Some(Value::Int(1))); + assert_eq!( + sharded_lru.get(REDIS_KEY, CMD_KEY_2, 0), + Some(Value::Int(2)) + ); std::thread::sleep(Duration::from_millis(6)); assert_eq!( - sharded_lru.get(REDIS_KEY, CMD_KEY), + sharded_lru.get(REDIS_KEY, CMD_KEY, 0), None, "Cache must be expired" ); assert_eq!( - sharded_lru.get(REDIS_KEY, CMD_KEY_2), + sharded_lru.get(REDIS_KEY, CMD_KEY_2, 0), None, "Cache must be expired" ); } + + #[test] + fn test_invalidate() { + let sharded_lru = ShardedLRU::new(NonZeroUsize::new(64).unwrap()); + + sharded_lru.insert( + REDIS_KEY, + CMD_KEY, + Value::Boolean(true), + Instant::now().add(Duration::from_secs(10)), + 0, + ); + assert_eq!( + sharded_lru.get(REDIS_KEY, CMD_KEY, 0), + Some(Value::Boolean(true)) + ); + assert_eq!( + sharded_lru.get(REDIS_KEY, CMD_KEY_2, 0), + None, + "Using different cmd key must result in cache miss" + ); + + sharded_lru.invalidate(&REDIS_KEY.to_vec()); + assert_eq!( + sharded_lru.get(REDIS_KEY, CMD_KEY, 0), + None, + "Cache must be invalidated" + ); + } + + #[test] + fn test_epoch_change() { + let sharded_lru = ShardedLRU::new(NonZeroUsize::new(64).unwrap()); + + let another_key = "foobar"; + + sharded_lru.insert( + REDIS_KEY, + CMD_KEY, + Value::Boolean(true), + Instant::now().add(Duration::from_secs(10)), + 0, + ); + sharded_lru.insert( + another_key.as_bytes(), + CMD_KEY, + Value::Boolean(true), + Instant::now().add(Duration::from_secs(10)), + 0, + ); + assert_eq!( + sharded_lru.get(REDIS_KEY, CMD_KEY, 0), + Some(Value::Boolean(true)) + ); + assert_eq!( + sharded_lru.get(REDIS_KEY, CMD_KEY_2, 0), + None, + "Using different cmd key must result in cache miss" + ); + + assert_eq!( + sharded_lru.get(REDIS_KEY, CMD_KEY, 1), + None, + "Cache must be invalidated" + ); + assert_eq!( + sharded_lru.get(another_key.as_bytes(), CMD_KEY, 1), + None, + "Cache must be invalidated" + ); + } } diff --git a/redis/src/client.rs b/redis/src/client.rs index cb0b384e4..4c883dd70 100644 --- a/redis/src/client.rs +++ b/redis/src/client.rs @@ -1,9 +1,9 @@ use std::time::Duration; #[cfg(feature = "aio")] -use crate::aio::AsyncPushSender; +use crate::aio::{AsyncPushSender, DefaultAsyncDNSResolver}; #[cfg(feature = "aio")] -use crate::io::tcp::TcpSettings; +use crate::io::{tcp::TcpSettings, AsyncDNSResolver}; use crate::{ connection::{connect, Connection, ConnectionInfo, ConnectionLike, IntoConnectionInfo}, types::{RedisResult, Value}, @@ -16,6 +16,8 @@ use crate::tls::{inner_build_with_tls, TlsCertificates}; #[cfg(feature = "cache-aio")] use crate::caching::CacheConfig; +#[cfg(all(feature = "cache-aio", feature = "connection-manager"))] +use crate::caching::CacheManager; /// The client type. 
#[derive(Debug, Clone)] @@ -161,6 +163,14 @@ impl Client { } } +#[cfg(feature = "cache-aio")] +#[derive(Clone)] +pub(crate) enum Cache { + Config(CacheConfig), + #[cfg(feature = "connection-manager")] + Manager(CacheManager), +} + /// Options for creation of async connection #[cfg(feature = "aio")] #[derive(Clone, Default)] @@ -171,8 +181,9 @@ pub struct AsyncConnectionConfig { pub(crate) connection_timeout: Option, pub(crate) push_sender: Option>, #[cfg(feature = "cache-aio")] - pub(crate) cache_config: Option, + pub(crate) cache: Option, pub(crate) tcp_settings: TcpSettings, + pub(crate) dns_resolver: Option>, } #[cfg(feature = "aio")] @@ -234,7 +245,13 @@ impl AsyncConnectionConfig { /// Sets cache config for MultiplexedConnection, check CacheConfig for more details. #[cfg(feature = "cache-aio")] pub fn set_cache_config(mut self, cache_config: CacheConfig) -> Self { - self.cache_config = Some(cache_config); + self.cache = Some(Cache::Config(cache_config)); + self + } + + #[cfg(all(feature = "cache-aio", feature = "connection-manager"))] + pub(crate) fn set_cache_manager(mut self, cache_manager: CacheManager) -> Self { + self.cache = Some(Cache::Manager(cache_manager)); self } @@ -245,6 +262,21 @@ impl AsyncConnectionConfig { ..self } } + + /// Set the DNS resolver for the underlying TCP connection. + /// + /// The parameter resolver must implement the [`crate::io::AsyncDNSResolver`] trait. + pub fn set_dns_resolver(self, dns_resolver: impl AsyncDNSResolver) -> Self { + self.set_dns_resolver_internal(std::sync::Arc::new(dns_resolver)) + } + + pub(super) fn set_dns_resolver_internal( + mut self, + dns_resolver: std::sync::Arc, + ) -> Self { + self.dns_resolver = Some(dns_resolver); + self + } } /// To enable async support you need to chose one of the supported runtimes and active its @@ -253,57 +285,8 @@ impl AsyncConnectionConfig { #[cfg_attr(docsrs, doc(cfg(feature = "aio")))] impl Client { /// Returns an async connection from the client. - #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] - #[deprecated( - note = "aio::Connection is deprecated. Use client::get_multiplexed_async_connection instead." - )] - #[allow(deprecated)] - pub async fn get_async_connection(&self) -> RedisResult { - let con = self - .get_simple_async_connection_dynamically(&TcpSettings::default()) - .await?; - - crate::aio::Connection::new(&self.connection_info.redis, con).await - } - - /// Returns an async connection from the client. - #[cfg(feature = "tokio-comp")] - #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] - #[deprecated( - note = "aio::Connection is deprecated. Use client::get_multiplexed_async_connection instead." - )] - #[allow(deprecated)] - pub async fn get_tokio_connection(&self) -> RedisResult { - use crate::aio::RedisRuntime; - Ok( - crate::aio::connect::(&self.connection_info) - .await? - .map(RedisRuntime::boxed), - ) - } - - /// Returns an async connection from the client. - #[cfg(feature = "async-std-comp")] - #[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] - #[deprecated( - note = "aio::Connection is deprecated. Use client::get_multiplexed_async_std_connection instead." - )] - #[allow(deprecated)] - pub async fn get_async_std_connection(&self) -> RedisResult { - use crate::aio::RedisRuntime; - Ok( - crate::aio::connect::(&self.connection_info) - .await? - .map(RedisRuntime::boxed), - ) - } - - /// Returns an async connection from the client. 
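The new `Cache` enum and the renamed `cache`/`dns_resolver` fields above are reached through the existing builder-style setters on `AsyncConnectionConfig`. A usage sketch, assuming the `aio` and `cache-aio` features are enabled and that `CacheConfig` provides a `Default` implementation:

```rust
use std::time::Duration;

// Build a multiplexed connection with client-side caching and a response
// timeout; the address is illustrative.
async fn connect_with_cache() -> redis::RedisResult<redis::aio::MultiplexedConnection> {
    let config = redis::AsyncConnectionConfig::new()
        .set_cache_config(redis::caching::CacheConfig::default())
        .set_response_timeout(Duration::from_secs(1));
    redis::Client::open("redis://127.0.0.1/")?
        .get_multiplexed_async_connection_with_config(&config)
        .await
}
```
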
- #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] - #[cfg_attr( - docsrs, - doc(cfg(any(feature = "tokio-comp", feature = "async-std-comp"))) - )] + #[cfg(feature = "aio")] + #[cfg_attr(docsrs, doc(cfg(feature = "aio")))] pub async fn get_multiplexed_async_connection( &self, ) -> RedisResult { @@ -312,11 +295,8 @@ impl Client { } /// Returns an async connection from the client. - #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] - #[cfg_attr( - docsrs, - doc(cfg(any(feature = "tokio-comp", feature = "async-std-comp"))) - )] + #[cfg(feature = "aio")] + #[cfg_attr(docsrs, doc(cfg(feature = "aio")))] #[deprecated(note = "Use `get_multiplexed_async_connection_with_config` instead")] pub async fn get_multiplexed_async_connection_with_timeouts( &self, @@ -332,56 +312,31 @@ impl Client { } /// Returns an async connection from the client. - #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] - #[cfg_attr( - docsrs, - doc(cfg(any(feature = "tokio-comp", feature = "async-std-comp"))) - )] + #[cfg(feature = "aio")] + #[cfg_attr(docsrs, doc(cfg(feature = "aio")))] pub async fn get_multiplexed_async_connection_with_config( &self, config: &AsyncConnectionConfig, ) -> RedisResult { - let result = match Runtime::locate() { + match Runtime::locate() { #[cfg(feature = "tokio-comp")] - rt @ Runtime::Tokio => { - if let Some(connection_timeout) = config.connection_timeout { - rt.timeout( - connection_timeout, - self.get_multiplexed_async_connection_inner::( - config, - ), - ) - .await - } else { - Ok(self - .get_multiplexed_async_connection_inner::(config) - .await) - } - } - #[cfg(feature = "async-std-comp")] - rt @ Runtime::AsyncStd => { - if let Some(connection_timeout) = config.connection_timeout { - rt.timeout( - connection_timeout, - self.get_multiplexed_async_connection_inner::( - config, - ), - ) - .await - } else { - Ok(self - .get_multiplexed_async_connection_inner::( - config, - ) - .await) - } - } - }; + rt @ Runtime::Tokio => self + .get_multiplexed_async_connection_inner_with_timeout::( + config, rt, + ) + .await, - match result { - Ok(Ok(connection)) => Ok(connection), - Ok(Err(e)) => Err(e), - Err(elapsed) => Err(elapsed.into()), + #[cfg(feature = "async-std-comp")] + rt @ Runtime::AsyncStd => self.get_multiplexed_async_connection_inner_with_timeout::< + crate::aio::async_std::AsyncStd, + >(config, rt) + .await, + + #[cfg(feature = "smol-comp")] + rt @ Runtime::Smol => self.get_multiplexed_async_connection_inner_with_timeout::< + crate::aio::smol::Smol, + >(config, rt) + .await, } } @@ -770,6 +725,33 @@ impl Client { crate::aio::ConnectionManager::new_with_config(self.clone(), config).await } + async fn get_multiplexed_async_connection_inner_with_timeout( + &self, + config: &AsyncConnectionConfig, + rt: Runtime, + ) -> RedisResult + where + T: crate::aio::RedisRuntime, + { + let result = if let Some(connection_timeout) = config.connection_timeout { + rt.timeout( + connection_timeout, + self.get_multiplexed_async_connection_inner::(config), + ) + .await + } else { + Ok(self + .get_multiplexed_async_connection_inner::(config) + .await) + }; + + match result { + Ok(Ok(connection)) => Ok(connection), + Ok(Err(e)) => Err(e), + Err(elapsed) => Err(elapsed.into()), + } + } + async fn get_multiplexed_async_connection_inner( &self, config: &AsyncConnectionConfig, @@ -795,8 +777,12 @@ impl Client { where T: crate::aio::RedisRuntime, { + let resolver = config + .dns_resolver + .as_deref() + .unwrap_or(&DefaultAsyncDNSResolver); let con = self - 
.get_simple_async_connection::(&config.tcp_settings) + .get_simple_async_connection::(resolver, &config.tcp_settings) .await?; crate::aio::MultiplexedConnection::new_with_config( &self.connection_info.redis, @@ -808,32 +794,49 @@ impl Client { async fn get_simple_async_connection_dynamically( &self, + dns_resolver: &dyn AsyncDNSResolver, tcp_settings: &TcpSettings, ) -> RedisResult>> { match Runtime::locate() { #[cfg(feature = "tokio-comp")] Runtime::Tokio => { - self.get_simple_async_connection::(tcp_settings) - .await + self.get_simple_async_connection::( + dns_resolver, + tcp_settings, + ) + .await } #[cfg(feature = "async-std-comp")] Runtime::AsyncStd => { - self.get_simple_async_connection::(tcp_settings) - .await + self.get_simple_async_connection::( + dns_resolver, + tcp_settings, + ) + .await + } + + #[cfg(feature = "smol-comp")] + Runtime::Smol => { + self.get_simple_async_connection::( + dns_resolver, + tcp_settings, + ) + .await } } } async fn get_simple_async_connection( &self, + dns_resolver: &dyn AsyncDNSResolver, tcp_settings: &TcpSettings, ) -> RedisResult>> where T: crate::aio::RedisRuntime, { Ok( - crate::aio::connect_simple::(&self.connection_info, tcp_settings) + crate::aio::connect_simple::(&self.connection_info, dns_resolver, tcp_settings) .await? .boxed(), ) @@ -845,24 +848,29 @@ impl Client { } /// Returns an async receiver for pub-sub messages. - #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] + #[cfg(feature = "aio")] // TODO - do we want to type-erase pubsub using a trait, to allow us to replace it with a different implementation later? pub async fn get_async_pubsub(&self) -> RedisResult { let connection = self - .get_simple_async_connection_dynamically(&TcpSettings::default()) + .get_simple_async_connection_dynamically( + &DefaultAsyncDNSResolver, + &TcpSettings::default(), + ) .await?; crate::aio::PubSub::new(&self.connection_info.redis, connection).await } /// Returns an async receiver for monitor messages. - #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] - // TODO - do we want to type-erase monitor using a trait, to allow us to replace it with a different implementation later? + #[cfg(feature = "aio")] pub async fn get_async_monitor(&self) -> RedisResult { - #[allow(deprecated)] - self.get_async_connection() - .await - .map(|connection| connection.into_monitor()) + let connection = self + .get_simple_async_connection_dynamically( + &DefaultAsyncDNSResolver, + &TcpSettings::default(), + ) + .await?; + crate::aio::Monitor::new(&self.connection_info.redis, connection).await } } diff --git a/redis/src/cluster.rs b/redis/src/cluster.rs index 1bb1da03e..be022fa1c 100644 --- a/redis/src/cluster.rs +++ b/redis/src/cluster.rs @@ -90,10 +90,6 @@ use rand::{rng, seq::IteratorRandom, Rng}; pub use crate::cluster_client::{ClusterClient, ClusterClientBuilder}; pub use crate::cluster_pipeline::{cluster_pipe, ClusterPipeline}; -#[cfg(feature = "tls-rustls")] -use crate::tls::TlsConnParams; - -#[cfg(not(feature = "tls-rustls"))] use crate::connection::TlsConnParams; #[derive(Clone)] @@ -235,6 +231,8 @@ pub struct ClusterConfig { pub(crate) response_timeout: Option, #[cfg(feature = "cluster-async")] pub(crate) async_push_sender: Option>, + #[cfg(feature = "cluster-async")] + pub(crate) async_dns_resolver: Option>, } impl ClusterConfig { @@ -284,6 +282,15 @@ impl ClusterConfig { self.async_push_sender = Some(std::sync::Arc::new(sender)); self } + + /// Set asynchronous DNS resolver for the underlying TCP connection. 
+ /// + /// The parameter resolver must implement the [`crate::io::AsyncDNSResolver`] trait. + #[cfg(feature = "cluster-async")] + pub fn set_dns_resolver(mut self, resolver: impl crate::io::AsyncDNSResolver) -> Self { + self.async_dns_resolver = Some(std::sync::Arc::new(resolver)); + self + } } /// This represents a Redis Cluster connection. @@ -312,7 +319,7 @@ where connections: RefCell::new(HashMap::new()), slots: RefCell::new(SlotMap::new(cluster_params.read_from_replicas)), auto_reconnect: RefCell::new(true), - read_timeout: RefCell::new(Some(cluster_params.response_timeout)), + read_timeout: RefCell::new(cluster_params.response_timeout), write_timeout: RefCell::new(None), initial_nodes: initial_nodes.to_vec(), cluster_params, diff --git a/redis/src/cluster_async/mod.rs b/redis/src/cluster_async/mod.rs index 12e51cb50..271da64c9 100644 --- a/redis/src/cluster_async/mod.rs +++ b/redis/src/cluster_async/mod.rs @@ -126,7 +126,7 @@ use futures_util::{ ready, stream::{self, Stream, StreamExt}, }; -use log::{trace, warn}; +use log::{debug, trace, warn}; use rand::{rng, seq::IteratorRandom}; use request::{CmdArg, PendingRequest, Request, RequestState, Retry}; use routing::{route_for_pipeline, InternalRoutingInfo, InternalSingleNodeRouting}; @@ -135,6 +135,8 @@ use tokio::sync::{mpsc, oneshot, RwLock}; struct ClientSideState { protocol: ProtocolVersion, _task_handle: HandleContainer, + response_timeout: Option, + runtime: Runtime, } /// This represents an async Redis Cluster connection. @@ -156,6 +158,8 @@ where cluster_params: ClusterParams, ) -> RedisResult> { let protocol = cluster_params.protocol.unwrap_or_default(); + let response_timeout = cluster_params.response_timeout; + let runtime = Runtime::locate(); ClusterConnInner::new(initial_nodes, cluster_params) .await .map(|inner| { @@ -166,13 +170,15 @@ where .forward(inner) .await; }; - let _task_handle = HandleContainer::new(Runtime::locate().spawn(stream)); + let _task_handle = HandleContainer::new(runtime.spawn(stream)); ClusterConnection { sender, state: Arc::new(ClientSideState { protocol, _task_handle, + response_timeout, + runtime, }), } }) @@ -182,33 +188,41 @@ where pub async fn route_command(&mut self, cmd: &Cmd, routing: RoutingInfo) -> RedisResult { trace!("send_packed_command"); let (sender, receiver) = oneshot::channel(); - self.sender - .send(Message { - cmd: CmdArg::Cmd { - cmd: Arc::new(cmd.clone()), // TODO Remove this clone? - routing: routing.into(), - }, - sender, - }) - .await - .map_err(|_| { - RedisError::from(io::Error::new( - io::ErrorKind::BrokenPipe, - "redis_cluster: Unable to send command", - )) - })?; - receiver - .await - .unwrap_or_else(|_| { - Err(RedisError::from(io::Error::new( - io::ErrorKind::BrokenPipe, - "redis_cluster: Unable to receive command", - ))) - }) - .map(|response| match response { - Response::Single(value) => value, - Response::Multiple(_) => unreachable!(), - }) + let request = async { + self.sender + .send(Message { + cmd: CmdArg::Cmd { + cmd: Arc::new(cmd.clone()), // TODO Remove this clone? 
+ routing: routing.into(), + }, + sender, + }) + .await + .map_err(|_| { + RedisError::from(io::Error::new( + io::ErrorKind::BrokenPipe, + "redis_cluster: Unable to send command", + )) + })?; + + receiver + .await + .unwrap_or_else(|_| { + Err(RedisError::from(io::Error::new( + io::ErrorKind::BrokenPipe, + "redis_cluster: Unable to receive command", + ))) + }) + .map(|response| match response { + Response::Single(value) => value, + Response::Multiple(_) => unreachable!(), + }) + }; + + match self.state.response_timeout { + Some(duration) => self.state.runtime.timeout(duration, request).await?, + None => request.await, + } } /// Send commands in `pipeline` to the given `route`. If `route` is [None], it will be sent to a random node. @@ -220,36 +234,43 @@ where route: SingleNodeRoutingInfo, ) -> RedisResult> { let (sender, receiver) = oneshot::channel(); - self.sender - .send(Message { - cmd: CmdArg::Pipeline { - pipeline: Arc::new(pipeline.clone()), // TODO Remove this clone? - offset, - count, - route: route.into(), - }, - sender, - }) - .await - .map_err(|_| closed_connection_error())?; - receiver - .await - .unwrap_or_else(|_| Err(closed_connection_error())) - .map(|response| match response { - Response::Multiple(values) => values, - Response::Single(_) => unreachable!(), - }) + let request = async { + self.sender + .send(Message { + cmd: CmdArg::Pipeline { + pipeline: Arc::new(pipeline.clone()), // TODO Remove this clone? + offset, + count, + route: route.into(), + }, + sender, + }) + .await + .map_err(|_| closed_connection_error())?; + receiver + .await + .unwrap_or_else(|_| Err(closed_connection_error())) + .map(|response| match response { + Response::Multiple(values) => values, + Response::Single(_) => unreachable!(), + }) + }; + + match self.state.response_timeout { + Some(duration) => self.state.runtime.timeout(duration, request).await?, + None => request.await, + } } - /// Subscribes to a new channel(s). + /// Subscribes to a new channel(s). /// /// Updates from the sender will be sent on the push sender that was passed to the manager. /// If the manager was configured without a push sender, the connection won't be able to pass messages back to the user. /// /// This method is only available when the connection is using RESP3 protocol, and will return an error otherwise. /// It should be noted that the subscription will be automatically resubscribed after disconnections, so the user might - /// receive additional pushes with [crate::PushKind::Subcribe], later after the subscription completed. + /// receive additional pushes with [crate::PushKind::Subscribe], later after the subscription completed. pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> { check_resp3!(self.state.protocol); let mut cmd = cmd("SUBSCRIBE"); @@ -276,7 +297,7 @@ where /// /// This method is only available when the connection is using RESP3 protocol, and will return an error otherwise. /// It should be noted that the subscription will be automatically resubscribed after disconnections, so the user might - /// receive additional pushes with [crate::PushKind::PSubcribe], later after the subscription completed. + /// receive additional pushes with [crate::PushKind::PSubscribe], later after the subscription completed. 
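Both `route_command` and `route_pipeline` above now wrap the whole send-and-receive future in an optional timeout held in `ClientSideState`, instead of configuring the timeout per node connection. The pattern, sketched against Tokio (assuming its `time` feature; the helper name is illustrative):

```rust
use std::io;
use std::time::Duration;

// Race the complete request future against a timer only when a timeout was
// configured; `None` still means "wait indefinitely".
async fn with_optional_timeout<T>(
    response_timeout: Option<Duration>,
    request: impl std::future::Future<Output = io::Result<T>>,
) -> io::Result<T> {
    match response_timeout {
        Some(duration) => tokio::time::timeout(duration, request)
            .await
            .map_err(|_elapsed| io::Error::new(io::ErrorKind::TimedOut, "response timed out"))?,
        None => request.await,
    }
}
```
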
pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> { check_resp3!(self.state.protocol); let mut cmd = cmd("PSUBSCRIBE"); @@ -303,7 +324,7 @@ where /// /// This method is only available when the connection is using RESP3 protocol, and will return an error otherwise. /// It should be noted that the subscription will be automatically resubscribed after disconnections, so the user might - /// receive additional pushes with [crate::PushKind::SSubcribe], later after the subscription completed. + /// receive additional pushes with [crate::PushKind::SSubscribe], later after the subscription completed. pub async fn ssubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> { check_resp3!(self.state.protocol); let mut cmd = cmd("SSUBSCRIBE"); @@ -324,8 +345,7 @@ where } } -type ConnectionFuture = future::Shared>; -type ConnectionMap = HashMap>; +type ConnectionMap = HashMap; /// This is the internal representation of an async Redis Cluster connection. It stores the /// underlying connections maintained for each node in the cluster, as well @@ -444,9 +464,9 @@ where let addr = info.addr.to_string(); let result = connect_and_check(&addr, params).await; match result { - Ok(conn) => Ok((addr, async { conn }.boxed().shared())), + Ok(conn) => Ok((addr, conn)), Err(e) => { - trace!("Failed to connect to initial node: {:?}", e); + debug!("Failed to connect to initial node: {:?}", e); Err(e) } } @@ -515,6 +535,7 @@ where } fn reconnect_to_initial_nodes(&mut self) -> impl Future { + debug!("Received request to reconnect to initial nodes"); let inner = self.inner.clone(); async move { let connection_map = @@ -543,85 +564,76 @@ where let inner = self.inner.clone(); async move { let mut write_guard = inner.conn_lock.write().await; - let mut connections = stream::iter(addrs) - .fold( - mem::take(&mut write_guard.0), - |mut connections, addr| async { - let conn = Self::get_or_create_conn( - &addr, - connections.remove(&addr), - &inner.cluster_params, - ) - .await; - if let Ok(conn) = conn { - connections.insert(addr, async { conn }.boxed().shared()); - } - connections - }, - ) - .await; - write_guard.0 = mem::take(&mut connections); + + Self::refresh_connections_locked(&inner, &mut write_guard.0, addrs).await; } } // Query a node to discover slot-> master mappings. 
async fn refresh_slots(inner: Core) -> RedisResult<()> { let mut write_guard = inner.conn_lock.write().await; - let mut connections = mem::take(&mut write_guard.0); - let slots = &mut write_guard.1; + let (connections, slots) = &mut *write_guard; + let mut result = Ok(()); - for (addr, conn) in connections.iter_mut() { - let mut conn = conn.clone().await; - let value = match conn - .req_packed_command(&slot_cmd()) - .await - .and_then(|value| value.extract_error()) - { - Ok(value) => value, - Err(err) => { - result = Err(err); - continue; - } - }; - match parse_slots( - value, - inner.cluster_params.tls, - addr.rsplit_once(':').unwrap().0, - ) - .and_then(|v: Vec| Self::build_slot_map(slots, v)) - { - Ok(_) => { - result = Ok(()); - break; - } - Err(err) => result = Err(err), + for (addr, conn) in &mut *connections { + result = async { + let value = conn + .req_packed_command(&slot_cmd()) + .await + .and_then(|value| value.extract_error())?; + let v: Vec = parse_slots( + value, + inner.cluster_params.tls, + addr.rsplit_once(':').unwrap().0, + )?; + Self::build_slot_map(slots, v) + } + .await; + if result.is_ok() { + break; } } result?; - let mut nodes = write_guard.1.values().flatten().collect::>(); + let mut nodes = slots.values().flatten().cloned().collect::>(); nodes.sort_unstable(); nodes.dedup(); + Self::refresh_connections_locked(&inner, connections, nodes).await; + + Ok(()) + } + + async fn refresh_connections_locked( + inner: &Core, + connections: &mut ConnectionMap, + nodes: Vec, + ) { let nodes_len = nodes.len(); - let addresses_and_connections_iter = nodes - .into_iter() - .map(|addr| (addr, connections.remove(addr))); - write_guard.0 = stream::iter(addresses_and_connections_iter) + let addresses_and_connections_iter = nodes.into_iter().map(|addr| { + let value = connections.remove(&addr); + (addr, value) + }); + + let inner = &inner; + *connections = stream::iter(addresses_and_connections_iter) + .map(|(addr, connection)| async move { + ( + addr.clone(), + Self::get_or_create_conn(&addr, connection, &inner.cluster_params).await, + ) + }) + .buffer_unordered(nodes_len.max(8)) .fold( HashMap::with_capacity(nodes_len), - |mut connections, (addr, connection)| async { - let conn = - Self::get_or_create_conn(addr, connection, &inner.cluster_params).await; - if let Ok(conn) = conn { - connections.insert(addr.to_string(), async { conn }.boxed().shared()); + |mut connections, (addr, result)| async move { + if let Ok(conn) = result { + connections.insert(addr, conn); } connections }, ) .await; - - Ok(()) } fn build_slot_map(slot_map: &mut SlotMap, slots_data: Vec) -> RedisResult<()> { @@ -889,7 +901,7 @@ where .slot_addr_for_route(&route) .map(|addr| addr.to_string()), InternalSingleNodeRouting::Connection { identifier, conn } => { - return Ok((identifier, conn.await)); + return Ok((identifier, conn)); } InternalSingleNodeRouting::Redirect { redirect, .. 
} => { drop(read_guard); @@ -898,7 +910,7 @@ where } InternalSingleNodeRouting::ByAddress(address) => { if let Some(conn) = read_guard.0.get(&address).cloned() { - return Ok((address, conn.await)); + return Ok((address, conn)); } else { return Err(( ErrorKind::ClientError, @@ -916,7 +928,7 @@ where drop(read_guard); let addr_conn_option = match conn { - Some((addr, Some(conn))) => Some((addr, conn.await)), + Some((addr, Some(conn))) => Some((addr, conn)), Some((addr, None)) => connect_check_and_add(core.clone(), addr.clone()) .await .ok() @@ -928,11 +940,9 @@ where Some(tuple) => tuple, None => { let read_guard = core.conn_lock.read().await; - if let Some((random_addr, random_conn_future)) = - get_random_connection(&read_guard.0) - { + if let Some((random_addr, random_conn)) = get_random_connection(&read_guard.0) { drop(read_guard); - (random_addr, random_conn_future.await) + (random_addr, random_conn) } else { return Err( (ErrorKind::ClusterConnectionNotFound, "No connections found").into(), @@ -957,7 +967,7 @@ where let conn = read_guard.0.get(&addr).cloned(); drop(read_guard); let mut conn = match conn { - Some(conn) => conn.await, + Some(conn) => conn, None => connect_check_and_add(core.clone(), addr.clone()).await?, }; if asking { @@ -1089,11 +1099,10 @@ where async fn get_or_create_conn( addr: &str, - conn_option: Option>, + conn_option: Option, params: &ClusterParams, ) -> RedisResult { - if let Some(conn) = conn_option { - let mut conn = conn.await; + if let Some(mut conn) = conn_option { match check_connection(&mut conn).await { Ok(_) => Ok(conn), Err(_) => connect_and_check(addr, params.clone()).await, @@ -1266,23 +1275,6 @@ where pub trait Connect: Sized { /// Connect to a node, returning handle for command execution. fn connect_with_config<'a, T>(info: T, config: AsyncConnectionConfig) -> RedisFuture<'a, Self> - where - T: IntoConnectionInfo + Send + 'a, - { - // default implementation, for backwards compatibility - Self::connect( - info, - config.response_timeout.unwrap_or(Duration::MAX), - config.connection_timeout.unwrap_or(Duration::MAX), - ) - } - - /// Connect to a node, returning handle for command execution. 
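The rewritten refresh path above drops the shared-future `ConnectionMap` in favor of plain connections and redials nodes concurrently via `buffer_unordered`, keeping only the successes. An illustrative sketch of that fan-out with a hypothetical `connect` callback (the `max(8)` mirrors the `nodes_len.max(8)` bound in the patch):

```rust
use std::collections::HashMap;

use futures_util::future::BoxFuture;
use futures_util::stream::{self, StreamExt};

// Dial every node concurrently (at least eight in flight) and collect the
// connections that succeeded back into the map.
async fn refresh_connections<C>(
    nodes: Vec<String>,
    connect: impl Fn(String) -> BoxFuture<'static, std::io::Result<C>>,
) -> HashMap<String, C> {
    let concurrency = nodes.len().max(8);
    stream::iter(nodes)
        .map(|addr| {
            let fut = connect(addr.clone());
            async move { (addr, fut.await) }
        })
        .buffer_unordered(concurrency)
        .fold(HashMap::new(), |mut map, (addr, result)| async move {
            if let Ok(conn) = result {
                map.insert(addr, conn);
            }
            map
        })
        .await
}
```
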
- fn connect<'a, T>( - info: T, - response_timeout: Duration, - connection_timeout: Duration, - ) -> RedisFuture<'a, Self> where T: IntoConnectionInfo + Send + 'a; } @@ -1301,27 +1293,6 @@ impl Connect for MultiplexedConnection { } .boxed() } - - fn connect<'a, T>( - info: T, - response_timeout: Duration, - connection_timeout: Duration, - ) -> RedisFuture<'a, MultiplexedConnection> - where - T: IntoConnectionInfo + Send + 'a, - { - async move { - let connection_info = info.into_connection_info()?; - let client = crate::Client::open(connection_info)?; - let config = crate::AsyncConnectionConfig::new() - .set_connection_timeout(connection_timeout) - .set_response_timeout(response_timeout); - client - .get_multiplexed_async_connection_with_config(&config) - .await - } - .boxed() - } } async fn connect_check_and_add(core: Core, addr: String) -> RedisResult @@ -1331,11 +1302,7 @@ where match connect_and_check::(&addr, core.cluster_params.clone()).await { Ok(conn) => { let conn_clone = conn.clone(); - core.conn_lock - .write() - .await - .0 - .insert(addr, async { conn_clone }.boxed().shared()); + core.conn_lock.write().await.0.insert(addr, conn_clone); Ok(conn) } Err(err) => Err(err), @@ -1351,15 +1318,27 @@ where let response_timeout = params.response_timeout; let push_sender = params.async_push_sender.clone(); let tcp_settings = params.tcp_settings.clone(); + let dns_resolver = params.async_dns_resolver.clone(); let info = get_connection_info(node, params)?; let mut config = AsyncConnectionConfig::default() .set_connection_timeout(connection_timeout) - .set_response_timeout(response_timeout) .set_tcp_settings(tcp_settings); + if let Some(response_timeout) = response_timeout { + config = config.set_response_timeout(response_timeout); + }; if let Some(push_sender) = push_sender { config = config.set_push_sender_internal(push_sender); } - let mut conn: C = C::connect_with_config(info, config).await?; + if let Some(resolver) = dns_resolver { + config = config.set_dns_resolver_internal(resolver.clone()); + } + let mut conn = match C::connect_with_config(info, config).await { + Ok(conn) => conn, + Err(err) => { + warn!("Failed to connect to node: {:?}, due to: {:?}", node, err); + return Err(err); + } + }; let check = if read_from_replicas { // If READONLY is sent to primary nodes, it will have no effect @@ -1382,7 +1361,7 @@ where Ok(()) } -fn get_random_connection(connections: &ConnectionMap) -> Option<(String, ConnectionFuture)> +fn get_random_connection(connections: &ConnectionMap) -> Option<(String, C)> where C: Clone, { diff --git a/redis/src/cluster_async/routing.rs b/redis/src/cluster_async/routing.rs index 506d79b82..3389d8e4e 100644 --- a/redis/src/cluster_async/routing.rs +++ b/redis/src/cluster_async/routing.rs @@ -6,8 +6,6 @@ use crate::{ Cmd, ErrorKind, RedisResult, }; -use super::ConnectionFuture; - #[derive(Clone)] pub(super) enum InternalRoutingInfo { SingleNode(InternalSingleNodeRouting), @@ -40,7 +38,7 @@ pub(super) enum InternalSingleNodeRouting { ByAddress(String), Connection { identifier: String, - conn: ConnectionFuture, + conn: C, }, Redirect { redirect: Redirect, @@ -116,7 +114,7 @@ mod pipeline_routing_tests { let mut pipeline = crate::Pipeline::new(); pipeline - .add_command(cmd("FLUSHALL")) // route to all masters + .flushall() // route to all masters .get("foo") // route to slot 12182 .add_command(cmd("EVAL")); // route randomly @@ -131,7 +129,7 @@ mod pipeline_routing_tests { let mut pipeline = crate::Pipeline::new(); pipeline - .add_command(cmd("FLUSHALL")) // route 
to all masters + .flushall() // route to all masters .add_command(cmd("EVAL")); // route randomly assert_eq!(route_for_pipeline(&pipeline), Ok(None)); @@ -143,7 +141,7 @@ mod pipeline_routing_tests { pipeline .get("foo") // route to replica of slot 12182 - .add_command(cmd("FLUSHALL")) // route to all masters + .flushall() // route to all masters .add_command(cmd("EVAL"))// route randomly .cmd("CONFIG").arg("GET").arg("timeout") // unkeyed command .set("foo", "bar"); // route to primary of slot 12182 @@ -159,7 +157,7 @@ mod pipeline_routing_tests { let mut pipeline = crate::Pipeline::new(); pipeline - .add_command(cmd("FLUSHALL")) // route to all masters + .flushall() // route to all masters .set("baz", "bar") // route to slot 4813 .get("foo"); // route to slot 12182 diff --git a/redis/src/cluster_client.rs b/redis/src/cluster_client.rs index 6c2013565..12a46d62b 100644 --- a/redis/src/cluster_client.rs +++ b/redis/src/cluster_client.rs @@ -1,7 +1,8 @@ #[cfg(feature = "cluster-async")] use crate::aio::AsyncPushSender; use crate::connection::{ConnectionAddr, ConnectionInfo, IntoConnectionInfo}; -use crate::io::tcp::TcpSettings; +#[cfg(feature = "cluster-async")] +use crate::io::{tcp::TcpSettings, AsyncDNSResolver}; use crate::types::{ErrorKind, ProtocolVersion, RedisError, RedisResult}; use crate::{cluster, cluster::TlsMode}; use rand::Rng; @@ -9,10 +10,6 @@ use rand::Rng; use std::sync::Arc; use std::time::Duration; -#[cfg(feature = "tls-rustls")] -use crate::tls::TlsConnParams; - -#[cfg(not(feature = "tls-rustls"))] use crate::connection::TlsConnParams; #[cfg(feature = "cluster-async")] @@ -32,13 +29,18 @@ struct BuilderParams { tls: Option, #[cfg(feature = "tls-rustls")] certs: Option, + #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))] + danger_accept_invalid_hostnames: bool, retries_configuration: RetryParams, connection_timeout: Option, response_timeout: Option, protocol: Option, #[cfg(feature = "cluster-async")] async_push_sender: Option>, + #[cfg(feature = "cluster-async")] pub(crate) tcp_settings: TcpSettings, + #[cfg(feature = "cluster-async")] + async_dns_resolver: Option>, } #[derive(Clone)] @@ -91,17 +93,20 @@ pub(crate) struct ClusterParams { pub(crate) retry_params: RetryParams, pub(crate) tls_params: Option, pub(crate) connection_timeout: Duration, - pub(crate) response_timeout: Duration, + pub(crate) response_timeout: Option, pub(crate) protocol: Option, #[cfg(feature = "cluster-async")] pub(crate) async_push_sender: Option>, + #[cfg(feature = "cluster-async")] pub(crate) tcp_settings: TcpSettings, + #[cfg(feature = "cluster-async")] + pub(crate) async_dns_resolver: Option>, } impl ClusterParams { fn from(value: BuilderParams) -> RedisResult { #[cfg(not(feature = "tls-rustls"))] - let tls_params = None; + let tls_params: Option = None; #[cfg(feature = "tls-rustls")] let tls_params = { @@ -110,6 +115,21 @@ impl ClusterParams { retrieved_tls_params.transpose()? 
}; + #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))] + let tls_params = if value.danger_accept_invalid_hostnames { + let mut tls_params = tls_params.unwrap_or(TlsConnParams { + #[cfg(feature = "tls-rustls")] + client_tls_params: None, + #[cfg(feature = "tls-rustls")] + root_cert_store: None, + danger_accept_invalid_hostnames: false, + }); + tls_params.danger_accept_invalid_hostnames = true; + Some(tls_params) + } else { + tls_params + }; + Ok(Self { password: value.password, username: value.username, @@ -118,11 +138,14 @@ impl ClusterParams { retry_params: value.retries_configuration, tls_params, connection_timeout: value.connection_timeout.unwrap_or(Duration::from_secs(1)), - response_timeout: value.response_timeout.unwrap_or(Duration::MAX), + response_timeout: value.response_timeout, protocol: value.protocol, #[cfg(feature = "cluster-async")] async_push_sender: value.async_push_sender, + #[cfg(feature = "cluster-async")] tcp_settings: value.tcp_settings, + #[cfg(feature = "cluster-async")] + async_dns_resolver: value.async_dns_resolver, }) } @@ -130,13 +153,18 @@ impl ClusterParams { if let Some(connection_timeout) = config.connection_timeout { self.connection_timeout = connection_timeout; } - if let Some(response_timeout) = config.response_timeout { - self.response_timeout = response_timeout; - } + self.response_timeout = config.response_timeout; + #[cfg(feature = "cluster-async")] if let Some(async_push_sender) = config.async_push_sender { self.async_push_sender = Some(async_push_sender); } + + #[cfg(feature = "cluster-async")] + if let Some(async_dns_resolver) = config.async_dns_resolver { + self.async_dns_resolver = Some(async_dns_resolver); + } + self } } @@ -306,6 +334,25 @@ impl ClusterClientBuilder { self } + /// Configure hostname verification when connecting with TLS. + /// + /// If `insecure` is true, this **disables** hostname verification, while + /// leaving other aspects of certificate checking enabled. This mode is + /// similar to what `redis-cli` does: TLS connections do check certificates, + /// but hostname errors are ignored. + /// + /// # Warning + /// + /// You should think very carefully before you use this method. If hostname + /// verification is not used, any valid certificate for any site will be + /// trusted for use from any other. This introduces a significant + /// vulnerability to man-in-the-middle attacks. + #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))] + pub fn danger_accept_invalid_hostnames(mut self, insecure: bool) -> ClusterClientBuilder { + self.builder_params.danger_accept_invalid_hostnames = insecure; + self + } + /// Sets raw TLS certificates for the new ClusterClient. /// /// When set, enforces the connection must be TLS secured. @@ -417,6 +464,15 @@ impl ClusterClientBuilder { self.builder_params.tcp_settings = tcp_settings; self } + + /// Set asynchronous DNS resolver for the underlying TCP connection. + /// + /// The parameter resolver must implement the [`crate::io::AsyncDNSResolver`] trait. + #[cfg(feature = "cluster-async")] + pub fn async_dns_resolver(mut self, resolver: impl AsyncDNSResolver) -> ClusterClientBuilder { + self.builder_params.async_dns_resolver = Some(Arc::new(resolver)); + self + } } /// A Redis Cluster client, used to create connections. 
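A usage sketch for the builder addition above. The insecure hostname mode requires one of the `tls-rustls-insecure` or `tls-native-tls` features, and the node address is illustrative:

```rust
use redis::cluster::ClusterClientBuilder;

// Certificates are still validated, but hostname mismatches are ignored,
// as described in the doc comment on `danger_accept_invalid_hostnames`.
fn build_cluster_client() -> redis::RedisResult<redis::cluster::ClusterClient> {
    ClusterClientBuilder::new(vec!["rediss://node1.example.com:6379"])
        .danger_accept_invalid_hostnames(true)
        .build()
}
```
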
diff --git a/redis/src/cluster_pipeline.rs b/redis/src/cluster_pipeline.rs
index ab2389d85..a904e5ce6 100644
--- a/redis/src/cluster_pipeline.rs
+++ b/redis/src/cluster_pipeline.rs
@@ -18,13 +18,13 @@ fn is_illegal_cmd(cmd: &str) -> bool {
         // All commands that start with "CONFIG"
         "CONFIG" | "CONFIG GET" | "CONFIG RESETSTAT" | "CONFIG REWRITE" | "CONFIG SET" |
         "DBSIZE" |
-        "ECHO" | "EVALSHA" |
+        "ECHO" |
         "FLUSHALL" | "FLUSHDB" |
         "INFO" |
         "KEYS" |
         "LASTSAVE" |
         "MGET" | "MOVE" | "MSET" | "MSETNX" |
-        "PFMERGE" | "PFCOUNT" | "PING" | "PUBLISH" |
+        "PING" | "PUBLISH" |
         "RANDOMKEY" | "RENAME" | "RENAMENX" | "RPOPLPUSH" |
         "SAVE" | "SCAN" |
         // All commands that start with "SCRIPT"
@@ -62,7 +62,7 @@ pub struct ClusterPipeline {
 /// KEYS
 /// LASTSAVE
 /// MGET, MOVE, MSET, MSETNX
-/// PFMERGE, PFCOUNT, PING, PUBLISH
+/// PING, PUBLISH
 /// RANDOMKEY, RENAME, RENAMENX, RPOPLPUSH
 /// SAVE, SCAN, SCRIPT EXISTS, SCRIPT FLUSH, SCRIPT KILL, SCRIPT LOAD, SDIFF, SDIFFSTORE,
 /// SENTINEL GET MASTER ADDR BY NAME, SENTINEL MASTER, SENTINEL MASTERS, SENTINEL MONITOR,
diff --git a/redis/src/cluster_routing.rs b/redis/src/cluster_routing.rs
index 6b3f1ca41..b0aba7dce 100644
--- a/redis/src/cluster_routing.rs
+++ b/redis/src/cluster_routing.rs
@@ -312,6 +312,59 @@ where
     })
 }
 
+/// Takes the given `routable` with possibly multiple keys and creates a single-slot routing info.
+/// This is used for commands like PFCOUNT or PFMERGE, where it is required that the command's keys
+/// all hash to the same slot and there is no way the command could be split on the client.
+/// Additionally, this function accepts an optional count of provided keys. This is useful for the
+/// EVAL and EVALSHA commands, which allow the user to pass an arbitrary number of keys and values.
+///
+/// If the keys are not all routed to the same slot, the `None` variant is returned and invoking
+/// such a command fails with UNROUTABLE_ERROR.
+fn multiple_keys_same_slot<R>(
+    routable: &R,
+    cmd: &[u8],
+    first_key_index: usize,
+    key_limit: Option<usize>,
+    allow_empty_keys: bool,
+) -> Option<RoutingInfo>
+where
+    R: Routable + ?Sized,
+{
+    let is_readonly = is_readonly_cmd(cmd);
+    let mut slots = HashSet::new();
+    let mut key_index = 0;
+    while let Some(key) = routable.arg_idx(first_key_index + key_index) {
+        if let Some(limit) = key_limit {
+            if key_index >= limit {
+                break;
+            }
+        }
+
+        slots.insert(get_slot(key));
+        key_index += 1;
+    }
+
+    if slots.is_empty() && allow_empty_keys {
+        return Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random));
+    }
+
+    if slots.len() != 1 {
+        return None;
+    }
+
+    let slot = slots.into_iter().next().unwrap();
+    Some(RoutingInfo::SingleNode(
+        SingleNodeRoutingInfo::SpecificNode(Route::new(
+            slot,
+            if is_readonly {
+                SlotAddr::ReplicaOptional
+            } else {
+                SlotAddr::Master
+            },
+        )),
+    ))
+}
+
 impl ResponsePolicy {
     /// Parse the command for the matching response policy.
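`multiple_keys_same_slot` only succeeds when every key hashes to one slot, which callers typically guarantee with hash tags, as the tests below exercise with `{hll}-1`/`{hll}-2`. A sketch of the tag rule (hypothetical helper; the crate computes real CRC16 slots via `get_slot`):

```rust
// Redis cluster hashes only the substring inside the first non-empty `{...}`
// section of a key, so keys sharing a tag are guaranteed the same slot.
fn hash_tag(key: &[u8]) -> &[u8] {
    if let Some(open) = key.iter().position(|&b| b == b'{') {
        if let Some(close) = key[open + 1..].iter().position(|&b| b == b'}') {
            if close > 0 {
                return &key[open + 1..open + 1 + close];
            }
        }
    }
    key
}

fn main() {
    // `PFMERGE {hll}-1 {hll}-2` is routable: both keys hash the tag `hll`.
    assert_eq!(hash_tag(b"{hll}-1"), hash_tag(b"{hll}-2"));
    // `PFMERGE hll-1 hll-2` may cross slots, so routing returns `None`.
    assert_ne!(hash_tag(b"hll-1"), hash_tag(b"hll-2"));
}
```
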
pub fn for_command(cmd: &[u8]) -> Option { @@ -473,18 +526,15 @@ impl RoutingInfo { b"MGET" | b"DEL" | b"EXISTS" | b"UNLINK" | b"TOUCH" => multi_shard(r, cmd, 1, false), b"MSET" => multi_shard(r, cmd, 1, true), + b"PFCOUNT" | b"PFMERGE" => multiple_keys_same_slot(r, cmd, 1, None, false), // TODO - special handling - b"SCAN" b"SCAN" | b"SHUTDOWN" | b"SLAVEOF" | b"REPLICAOF" | b"MOVE" | b"BITOP" => None, b"EVALSHA" | b"EVAL" => { let key_count = r .arg_idx(2) .and_then(|x| std::str::from_utf8(x).ok()) - .and_then(|x| x.parse::().ok())?; - if key_count == 0 { - Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) - } else { - r.arg_idx(3).map(|key| RoutingInfo::for_key(cmd, key)) - } + .and_then(|x| x.parse::().ok())?; + multiple_keys_same_slot(r, cmd, 3, Some(key_count), true) } b"XGROUP CREATE" | b"XGROUP CREATECONSUMER" @@ -620,7 +670,7 @@ pub enum SlotAddr { } /// This is just a simplified version of [`Slot`], -/// which stores only the master and [optional] replica +/// which stores only the master and optional replica /// to avoid the need to choose a replica each time /// a command is executed #[derive(Debug)] @@ -820,7 +870,7 @@ mod tests { RoutingInfo, SingleNodeRoutingInfo, Slot, SlotAddr, SlotMap, }; use crate::{ - cluster_routing::{AggregateOp, ResponsePolicy}, + cluster_routing::{get_slot, AggregateOp, ResponsePolicy}, cmd, parser::parse_redis_value, Value, @@ -879,16 +929,6 @@ mod tests { test_cmd.arg("GROUPS").arg("FOOBAR"); test_cmds.push(test_cmd); - // Routing key is 3rd or 4th arg (3rd = "0" == RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) - test_cmd = cmd("EVAL"); - test_cmd.arg("FOO").arg("0").arg("BAR"); - test_cmds.push(test_cmd); - - // Routing key is 3rd or 4th arg (3rd != "0" == RoutingInfo::Slot) - test_cmd = cmd("EVAL"); - test_cmd.arg("FOO").arg("4").arg("BAR"); - test_cmds.push(test_cmd); - // Routing key position is variable, 3rd arg test_cmd = cmd("XREAD"); test_cmd.arg("STREAMS").arg("4"); @@ -967,26 +1007,7 @@ mod tests { ); } - for cmd in [ - cmd("EVAL").arg(r#"redis.call("PING");"#).arg(0), - cmd("EVALSHA").arg(r#"redis.call("PING");"#).arg(0), - ] { - assert_eq!( - RoutingInfo::for_routable(cmd), - Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) - ); - } - for (cmd, expected) in [ - ( - cmd("EVAL") - .arg(r#"redis.call("GET, KEYS[1]");"#) - .arg(1) - .arg("foo"), - Some(RoutingInfo::SingleNode( - SingleNodeRoutingInfo::SpecificNode(Route::new(slot(b"foo"), SlotAddr::Master)), - )), - ), ( cmd("XGROUP") .arg("CREATE") @@ -1107,6 +1128,136 @@ mod tests { ); } + #[test] + fn test_multiple_keys_same_slot() { + // single key + let mut cmd = crate::cmd("PFCOUNT"); + cmd.arg("hll-1"); + let routing = RoutingInfo::for_routable(&cmd); + assert!(matches!( + routing, + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route(_, SlotAddr::ReplicaOptional)) + )) + )); + + let mut cmd = crate::cmd("PFMERGE"); + cmd.arg("hll-1"); + let routing = RoutingInfo::for_routable(&cmd); + assert!(matches!( + routing, + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route(_, SlotAddr::Master)) + )) + )); + + // multiple keys + let mut cmd = crate::cmd("PFCOUNT"); + cmd.arg("{hll}-1").arg("{hll}-2"); + let routing = RoutingInfo::for_routable(&cmd); + assert!(matches!( + routing, + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route(_, SlotAddr::ReplicaOptional)) + )) + )); + + let mut cmd = crate::cmd("PFMERGE"); + cmd.arg("{hll}-1").arg("{hll}-2"); + let routing = 
RoutingInfo::for_routable(&cmd); + assert!(matches!( + routing, + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route(_, SlotAddr::Master)) + )) + )); + + // same-slot violation + let mut cmd = crate::cmd("PFCOUNT"); + cmd.arg("hll-1").arg("hll-2"); + let routing = RoutingInfo::for_routable(&cmd); + assert!(routing.is_none()); + + let mut cmd = crate::cmd("PFMERGE"); + cmd.arg("hll-1").arg("hll-2"); + let routing = RoutingInfo::for_routable(&cmd); + assert!(routing.is_none()); + + // missing keys + let cmd = crate::cmd("PFCOUNT"); + let routing = RoutingInfo::for_routable(&cmd); + assert!(routing.is_none()); + + let cmd = crate::cmd("PFMERGE"); + let routing = RoutingInfo::for_routable(&cmd); + assert!(routing.is_none()); + } + + #[test] + fn test_eval_and_evalsha() { + // no key + let mut cmd = crate::cmd("EVAL"); + cmd.arg(r#""return 42""#).arg("0"); + let routing = RoutingInfo::for_routable(&cmd); + assert_eq!( + routing, + Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) + ); + + // just keys + let mut cmd = crate::cmd("EVAL"); + cmd.arg(r#""return {KEYS[1], KEYS[2]}""#) + .arg("2") + .arg("{k}1") + .arg("{k}2"); + let routing = RoutingInfo::for_routable(&cmd); + assert_eq!( + routing, + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route(get_slot(b"{k}1"), SlotAddr::Master)) + )) + ); + + // just values + let mut cmd = crate::cmd("EVALSHA"); + cmd.arg("c62ff9e46fd2e8a71f74500b9438c80df6af233c") + .arg("0") + .arg("v1") + .arg("v2"); + let routing = RoutingInfo::for_routable(&cmd); + assert_eq!( + routing, + Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) + ); + + // keys and values + let mut cmd = crate::cmd("EVALSHA"); + cmd.arg("c62ff9e46fd2e8a71f74500b9438c80df6af233c") + .arg("2") + .arg("{k}1") + .arg("{k}2") + .arg("v1") + .arg("v2"); + let routing = RoutingInfo::for_routable(&cmd); + assert_eq!( + routing, + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route(get_slot(b"{k}1"), SlotAddr::Master)) + )) + ); + + // same-slot violation + let mut cmd = crate::cmd("EVALSHA"); + cmd.arg("c62ff9e46fd2e8a71f74500b9438c80df6af233c") + .arg("2") + .arg("k1") + .arg("k2") + .arg("v1") + .arg("v2"); + let routing = RoutingInfo::for_routable(&cmd); + assert_eq!(routing, None); + } + #[test] fn test_command_creation_for_multi_shard() { let mut original_cmd = cmd("DEL"); diff --git a/redis/src/commands/mod.rs b/redis/src/commands/mod.rs index ed906e78f..5a716c218 100644 --- a/redis/src/commands/mod.rs +++ b/redis/src/commands/mod.rs @@ -2,8 +2,8 @@ use crate::cmd::{cmd, Cmd, Iter}; use crate::connection::{Connection, ConnectionLike, Msg}; use crate::pipeline::Pipeline; use crate::types::{ - ExistenceCheck, ExpireOption, Expiry, FromRedisValue, NumericBehavior, RedisResult, RedisWrite, - SetExpiry, ToRedisArgs, + ExistenceCheck, ExpireOption, Expiry, FieldExistenceCheck, FromRedisValue, NumericBehavior, + RedisResult, RedisWrite, SetExpiry, ToRedisArgs, }; #[macro_use] @@ -273,15 +273,7 @@ implement_commands! 
{ /// Get the value of a key and set expiration fn get_ex(key: K, expire_at: Expiry) { - let (option, time_arg) = match expire_at { - Expiry::EX(sec) => ("EX", Some(sec)), - Expiry::PX(ms) => ("PX", Some(ms)), - Expiry::EXAT(timestamp_sec) => ("EXAT", Some(timestamp_sec)), - Expiry::PXAT(timestamp_ms) => ("PXAT", Some(timestamp_ms)), - Expiry::PERSIST => ("PERSIST", None), - }; - - cmd("GETEX").arg(key).arg(option).arg(time_arg) + cmd("GETEX").arg(key).arg(expire_at) } /// Get the value of a key and delete it @@ -382,16 +374,31 @@ implement_commands! { cmd(if field.num_of_args() <= 1 { "HGET" } else { "HMGET" }).arg(key).arg(field) } + /// Get the value of one or more fields of a given hash key, and optionally set their expiration + fn hget_ex(key: K, fields: F, expire_at: Expiry) { + cmd("HGETEX").arg(key).arg(expire_at).arg("FIELDS").arg(fields.num_of_args()).arg(fields) + } + /// Deletes a single (or multiple) fields from a hash. fn hdel(key: K, field: F) { cmd("HDEL").arg(key).arg(field) } + /// Get and delete the value of one or more fields of a given hash key + fn hget_del(key: K, fields: F) { + cmd("HGETDEL").arg(key).arg("FIELDS").arg(fields.num_of_args()).arg(fields) + } + /// Sets a single field in a hash. fn hset(key: K, field: F, value: V) { cmd("HSET").arg(key).arg(field).arg(value) } + /// Set the value of one or more fields of a given hash key, and optionally set their expiration + fn hset_ex(key: K, hash_field_expiration_options: &'a HashFieldExpirationOptions, fields_values: &'a [(F, V)]) { + cmd("HSETEX").arg(key).arg(hash_field_expiration_options).arg("FIELDS").arg(fields_values.len()).arg(fields_values) + } + /// Sets a single field in a hash if it does not exist. fn hset_nx(key: K, field: F, value: V) { cmd("HSETNX").arg(key).arg(field).arg(value) @@ -2208,6 +2215,54 @@ assert_eq!(b, 5); fn invoke_script<>(invocation: &'a crate::ScriptInvocation<'a>) { &mut invocation.eval_cmd() } + + // cleanup commands + + /// Deletes all the keys of all databases + /// + /// Whether the flushing happens asynchronously or synchronously depends on the configuration + /// of your Redis server. + /// + /// To enforce a flush mode, use [`Commands::flushall_options`]. + /// + /// ```text + /// FLUSHALL + /// ``` + fn flushall<>() { + &mut cmd("FLUSHALL") + } + + /// Deletes all the keys of all databases with options + /// + /// ```text + /// FLUSHALL [ASYNC|SYNC] + /// ``` + fn flushall_options<>(options: &'a FlushAllOptions) { + cmd("FLUSHALL").arg(options) + } + + /// Deletes all the keys of the current database + /// + /// Whether the flushing happens asynchronously or synchronously depends on the configuration + /// of your Redis server. + /// + /// To enforce a flush mode, use [`Commands::flushdb_options`]. + /// + /// ```text + /// FLUSHDB + /// ``` + fn flushdb<>() { + &mut cmd("FLUSHDB") + } + + /// Deletes all the keys of the current database with options + /// + /// ```text + /// FLUSHDB [ASYNC|SYNC] + /// ``` + fn flushdb_options<>(options: &'a FlushDbOptions) { + cmd("FLUSHDB").arg(options) + } } /// Allows pubsub callbacks to stop receiving messages. 
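The new `hget_ex`/`hget_del`/`hset_ex` methods above mirror the HGETEX/HGETDEL/HSETEX commands. A usage sketch for `hset_ex`, assuming `HashFieldExpirationOptions` and `FieldExistenceCheck` are re-exported at the crate root like the other option types; the key and fields are illustrative:

```rust
use redis::{Commands, FieldExistenceCheck, HashFieldExpirationOptions, SetExpiry};

// Sends: HSETEX session FNX EX 60 FIELDS 2 token abc refresh def
fn cache_session(con: &mut redis::Connection) -> redis::RedisResult<()> {
    let opts = HashFieldExpirationOptions::default()
        .set_existence_check(FieldExistenceCheck::FNX) // only set if fields don't exist
        .set_expiration(SetExpiry::EX(60)); // expire the fields after 60 seconds
    let _: () = con.hset_ex("session", &opts, &[("token", "abc"), ("refresh", "def")])?;
    Ok(())
}
```
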
@@ -2592,6 +2647,136 @@ impl ToRedisArgs for SetOptions { } } +/// Options for the [FLUSHALL](https://redis.io/commands/flushall) command +/// +/// # Example +/// ```rust,no_run +/// use redis::{Commands, RedisResult, FlushAllOptions}; +/// fn flushall_sync( +/// con: &mut redis::Connection, +/// ) -> RedisResult<()> { +/// let opts = FlushAllOptions{blocking: true}; +/// con.flushall_options(&opts) +/// } +/// ``` +#[derive(Clone, Copy, Default)] +pub struct FlushAllOptions { + /// Blocking (`SYNC`) waits for completion, non-blocking (`ASYNC`) runs in the background + pub blocking: bool, +} + +impl FlushAllOptions { + /// Set whether to run blocking (`SYNC`) or non-blocking (`ASYNC`) flush + pub fn blocking(mut self, blocking: bool) -> Self { + self.blocking = blocking; + self + } +} + +impl ToRedisArgs for FlushAllOptions { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + if self.blocking { + out.write_arg(b"SYNC"); + } else { + out.write_arg(b"ASYNC"); + }; + } +} + +/// Options for the [FLUSHDB](https://redis.io/commands/flushdb) command +pub type FlushDbOptions = FlushAllOptions; + +/// Options for the HSETEX command +#[derive(Clone, Copy, Default)] +pub struct HashFieldExpirationOptions { + existence_check: Option, + expiration: Option, +} + +impl HashFieldExpirationOptions { + /// Set the field(s) existence check for the HSETEX command + pub fn set_existence_check(mut self, field_existence_check: FieldExistenceCheck) -> Self { + self.existence_check = Some(field_existence_check); + self + } + + /// Set the expiration option for the field(s) in the HSETEX command + pub fn set_expiration(mut self, expiration: SetExpiry) -> Self { + self.expiration = Some(expiration); + self + } +} + +impl ToRedisArgs for HashFieldExpirationOptions { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + if let Some(ref existence_check) = self.existence_check { + match existence_check { + FieldExistenceCheck::FNX => out.write_arg(b"FNX"), + FieldExistenceCheck::FXX => out.write_arg(b"FXX"), + } + } + + if let Some(ref expiration) = self.expiration { + match expiration { + SetExpiry::EX(secs) => { + out.write_arg(b"EX"); + out.write_arg(format!("{}", secs).as_bytes()); + } + SetExpiry::PX(millis) => { + out.write_arg(b"PX"); + out.write_arg(format!("{}", millis).as_bytes()); + } + SetExpiry::EXAT(unix_time) => { + out.write_arg(b"EXAT"); + out.write_arg(format!("{}", unix_time).as_bytes()); + } + SetExpiry::PXAT(unix_time) => { + out.write_arg(b"PXAT"); + out.write_arg(format!("{}", unix_time).as_bytes()); + } + SetExpiry::KEEPTTL => { + out.write_arg(b"KEEPTTL"); + } + } + } + } +} + +impl ToRedisArgs for Expiry { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + match self { + Expiry::EX(sec) => { + out.write_arg(b"EX"); + out.write_arg(sec.to_string().as_bytes()); + } + Expiry::PX(ms) => { + out.write_arg(b"PX"); + out.write_arg(ms.to_string().as_bytes()); + } + Expiry::EXAT(timestamp_sec) => { + out.write_arg(b"EXAT"); + out.write_arg(timestamp_sec.to_string().as_bytes()); + } + Expiry::PXAT(timestamp_ms) => { + out.write_arg(b"PXAT"); + out.write_arg(timestamp_ms.to_string().as_bytes()); + } + Expiry::PERSIST => { + out.write_arg(b"PERSIST"); + } + } + } +} + /// Creates HELLO command for RESP3 with RedisConnectionInfo pub fn resp3_hello(connection_info: &RedisConnectionInfo) -> Cmd { let mut hello_cmd = cmd("HELLO"); diff --git a/redis/src/connection.rs b/redis/src/connection.rs index b8c634d39..e79e0308f 
100644 --- a/redis/src/connection.rs +++ b/redis/src/connection.rs @@ -40,13 +40,19 @@ use crate::PushInfo; use rustls_native_certs::load_native_certs; #[cfg(feature = "tls-rustls")] -use crate::tls::TlsConnParams; +use crate::tls::ClientTlsParams; // Non-exhaustive to prevent construction outside this crate -#[cfg(not(feature = "tls-rustls"))] #[derive(Clone, Debug)] #[non_exhaustive] -pub struct TlsConnParams; +pub struct TlsConnParams { + #[cfg(feature = "tls-rustls")] + pub(crate) client_tls_params: Option, + #[cfg(feature = "tls-rustls")] + pub(crate) root_cert_store: Option, + #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))] + pub(crate) danger_accept_invalid_hostnames: bool, +} static DEFAULT_PORT: u16 = 6379; @@ -69,7 +75,9 @@ fn connect_tcp_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result Option { match url::Url::parse(input) { Ok(result) => match result.scheme() { - "redis" | "rediss" | "redis+unix" | "unix" => Some(result), + "redis" | "rediss" | "valkey" | "valkeys" | "redis+unix" | "valkey+unix" | "unix" => { + Some(result) + } _ => None, }, Err(_) => None, @@ -151,10 +159,13 @@ impl ConnectionAddr { /// Checks if this address is supported. /// /// Because not all platforms support all connection addresses this is a - /// quick way to figure out if a connection method is supported. Currently - /// this only affects unix connections which are only supported on unix - /// platforms and on older versions of rust also require an explicit feature - /// to be enabled. + /// quick way to figure out if a connection method is supported. Currently + /// this affects: + /// + /// - Unix socket addresses, which are supported only on Unix + /// + /// - TLS addresses, which are supported only if a TLS feature is enabled + /// (either `tls-native-tls` or `tls-rustls`). pub fn is_supported(&self) -> bool { match *self { ConnectionAddr::Tcp(_, _) => true, @@ -165,6 +176,31 @@ impl ConnectionAddr { } } + /// Configure this address to connect without checking certificate hostnames. + /// + /// # Warning + /// + /// You should think very carefully before you use this method. If hostname + /// verification is not used, any valid certificate for any site will be + /// trusted for use from any other. This introduces a significant + /// vulnerability to man-in-the-middle attacks. + #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))] + pub fn set_danger_accept_invalid_hostnames(&mut self, insecure: bool) { + if let ConnectionAddr::TcpTls { tls_params, .. 
} = self { + if let Some(ref mut params) = tls_params { + params.danger_accept_invalid_hostnames = insecure; + } else if insecure { + *tls_params = Some(TlsConnParams { + #[cfg(feature = "tls-rustls")] + client_tls_params: None, + #[cfg(feature = "tls-rustls")] + root_cert_store: None, + danger_accept_invalid_hostnames: insecure, + }); + } + } + } + #[cfg(feature = "cluster")] pub(crate) fn tls_mode(&self) -> Option { match self { @@ -236,7 +272,7 @@ impl IntoConnectionInfo for ConnectionInfo { } } -/// URL format: `{redis|rediss}://[][:@][:port][/]` +/// URL format: `{redis|rediss|valkey|valkeys}://[][:@][:port][/]` /// /// - Basic: `redis://127.0.0.1:6379` /// - Username & Password: `redis://user:password@127.0.0.1:6379` @@ -266,7 +302,7 @@ where } } -/// URL format: `{redis|rediss}://[][:@][:port][/]` +/// URL format: `{redis|rediss|valkey|valkeys}://[][:@][:port][/]` /// /// - Basic: `redis://127.0.0.1:6379` /// - Username & Password: `redis://user:password@127.0.0.1:6379` @@ -326,7 +362,7 @@ fn url_to_tcp_connection_info(url: url::Url) -> RedisResult { None => fail!((ErrorKind::InvalidClientConfig, "Missing hostname")), }; let port = url.port().unwrap_or(DEFAULT_PORT); - let addr = if url.scheme() == "rediss" { + let addr = if url.scheme() == "rediss" || url.scheme() == "valkeys" { #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))] { match url.fragment() { @@ -426,8 +462,8 @@ fn url_to_unix_connection_info(_: url::Url) -> RedisResult { impl IntoConnectionInfo for url::Url { fn into_connection_info(self) -> RedisResult { match self.scheme() { - "redis" | "rediss" => url_to_tcp_connection_info(self), - "unix" | "redis+unix" => url_to_unix_connection_info(self), + "redis" | "rediss" | "valkey" | "valkeys" => url_to_tcp_connection_info(self), + "unix" | "redis+unix" | "valkey+unix" => url_to_unix_connection_info(self), _ => fail!(( ErrorKind::InvalidClientConfig, "URL provided is not a redis URL" @@ -517,6 +553,84 @@ impl fmt::Debug for NoCertificateVerification { } } +/// Insecure `ServerCertVerifier` for rustls that implements `danger_accept_invalid_hostnames`. +#[cfg(feature = "tls-rustls-insecure")] +#[derive(Debug)] +struct AcceptInvalidHostnamesCertVerifier { + inner: Arc, +} + +#[cfg(feature = "tls-rustls-insecure")] +fn is_hostname_error(err: &rustls::Error) -> bool { + matches!( + err, + rustls::Error::InvalidCertificate( + rustls::CertificateError::NotValidForName + | rustls::CertificateError::NotValidForNameContext { .. 
} + ) + ) +} + +#[cfg(feature = "tls-rustls-insecure")] +impl rustls::client::danger::ServerCertVerifier for AcceptInvalidHostnamesCertVerifier { + fn verify_server_cert( + &self, + end_entity: &rustls::pki_types::CertificateDer<'_>, + intermediates: &[rustls::pki_types::CertificateDer<'_>], + server_name: &rustls::pki_types::ServerName<'_>, + ocsp_response: &[u8], + now: rustls::pki_types::UnixTime, + ) -> Result { + self.inner + .verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now) + .or_else(|err| { + if is_hostname_error(&err) { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } else { + Err(err) + } + }) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &rustls::pki_types::CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + self.inner + .verify_tls12_signature(message, cert, dss) + .or_else(|err| { + if is_hostname_error(&err) { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } else { + Err(err) + } + }) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &rustls::pki_types::CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + self.inner + .verify_tls13_signature(message, cert, dss) + .or_else(|err| { + if is_hostname_error(&err) { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } else { + Err(err) + } + }) + } + + fn supported_verify_schemes(&self) -> Vec { + self.inner.supported_verify_schemes() + } +} + /// Represents a stateful redis TCP connection. pub struct Connection { con: ActualConnection, @@ -599,7 +713,7 @@ impl ActualConnection { ref host, port, insecure, - .. + ref tls_params, } => { let tls_connector = if insecure { TlsConnector::builder() @@ -607,6 +721,10 @@ impl ActualConnection { .danger_accept_invalid_hostnames(true) .use_sni(false) .build()? + } else if let Some(params) = tls_params { + TlsConnector::builder() + .danger_accept_invalid_hostnames(params.danger_accept_invalid_hostnames) + .build()? } else { TlsConnector::new()? }; @@ -846,8 +964,6 @@ pub(crate) fn create_rustls_config( insecure: bool, tls_params: Option, ) -> RedisResult { - use crate::tls::ClientTlsParams; - #[allow(unused_mut)] let mut root_store = RootCertStore::empty(); #[cfg(feature = "tls-rustls-webpki-roots")] @@ -857,16 +973,22 @@ pub(crate) fn create_rustls_config( not(feature = "tls-native-tls"), not(feature = "tls-rustls-webpki-roots") ))] - for cert in load_native_certs()? { - root_store.add(cert)?; + { + let mut certificate_result = load_native_certs(); + if let Some(error) = certificate_result.errors.pop() { + return Err(error.into()); + } + for cert in certificate_result.certs { + root_store.add(cert)?; + } } let config = rustls::ClientConfig::builder(); let config = if let Some(tls_params) = tls_params { - let config_builder = - config.with_root_certificates(tls_params.root_cert_store.unwrap_or(root_store)); + let root_cert_store = tls_params.root_cert_store.unwrap_or(root_store); + let config_builder = config.with_root_certificates(root_cert_store.clone()); - if let Some(ClientTlsParams { + let config_builder = if let Some(ClientTlsParams { client_cert_chain: client_cert, client_key, }) = tls_params.client_tls_params @@ -882,7 +1004,44 @@ pub(crate) fn create_rustls_config( })? } else { config_builder.with_no_client_auth() - } + }; + + // Implement `danger_accept_invalid_hostnames`. 
+        //
+        // The strange cfg here is to handle a specific unusual combination of features: if
+        // `tls-native-tls` and `tls-rustls` are enabled, but `tls-rustls-insecure` is not, and the
+        // application tries to use the danger flag.
+        #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
+        let config_builder = if !insecure && tls_params.danger_accept_invalid_hostnames {
+            #[cfg(not(feature = "tls-rustls-insecure"))]
+            {
+                // This code should not enable an insecure mode if the `tls-rustls-insecure`
+                // feature is not set, but it shouldn't silently ignore the flag either. So
+                // return an error.
+                fail!((
+                    ErrorKind::InvalidClientConfig,
+                    "Cannot create insecure client via danger_accept_invalid_hostnames without tls-rustls-insecure feature"
+                ));
+            }
+
+            #[cfg(feature = "tls-rustls-insecure")]
+            {
+                let mut config = config_builder;
+                config.dangerous().set_certificate_verifier(Arc::new(
+                    AcceptInvalidHostnamesCertVerifier {
+                        inner: rustls::client::WebPkiServerVerifier::builder(Arc::new(
+                            root_cert_store,
+                        ))
+                        .build()
+                        .map_err(|err| rustls::Error::from(rustls::OtherError(Arc::new(err))))?,
+                    },
+                ));
+                config
+            }
+        } else {
+            config_builder
+        };
+
+        config_builder
     } else {
         config
             .with_root_certificates(root_store)
@@ -1178,7 +1337,7 @@ fn execute_connection_pipeline(
     rv: &mut Connection,
     (pipeline, instructions): (crate::Pipeline, ConnectionSetupComponents),
 ) -> RedisResult<AuthResult> {
-    if pipeline.len() == 0 {
+    if pipeline.is_empty() {
         return Ok(AuthResult::Succeeded);
     }
     let results = rv.req_packed_commands(&pipeline.get_packed_pipeline(), 0, pipeline.len())?;
@@ -2063,7 +2222,14 @@ mod tests {
         let cases = vec![
             ("redis://127.0.0.1", true),
             ("redis://[::1]", true),
+            ("rediss://127.0.0.1", true),
+            ("rediss://[::1]", true),
+            ("valkey://127.0.0.1", true),
+            ("valkey://[::1]", true),
+            ("valkeys://127.0.0.1", true),
+            ("valkeys://[::1]", true),
             ("redis+unix:///run/redis.sock", true),
+            ("valkey+unix:///run/valkey.sock", true),
             ("unix:///run/redis.sock", true),
             ("http://127.0.0.1", false),
             ("tcp://127.0.0.1", false),
diff --git a/redis/src/io/dns.rs b/redis/src/io/dns.rs
new file mode 100644
index 000000000..ff3d42f90
--- /dev/null
+++ b/redis/src/io/dns.rs
@@ -0,0 +1,13 @@
+#[cfg(feature = "aio")]
+use std::net::SocketAddr;
+
+/// An async DNS resolver for resolving Redis domains.
+#[cfg(feature = "aio")]
+pub trait AsyncDNSResolver: Send + Sync + 'static {
+    /// Resolves the host and port to a list of `SocketAddr`.
+    fn resolve<'a, 'b: 'a>(
+        &'a self,
+        host: &'b str,
+        port: u16,
+    ) -> crate::RedisFuture<'a, Box<dyn Iterator<Item = SocketAddr> + Send + 'a>>;
+}
diff --git a/redis/src/io/mod.rs b/redis/src/io/mod.rs
index fd54142a5..39b3ebdbc 100644
--- a/redis/src/io/mod.rs
+++ b/redis/src/io/mod.rs
@@ -1,2 +1,8 @@
 /// Module for defining the TCP settings and behavior.
 pub mod tcp;
+
+#[cfg(feature = "aio")]
+mod dns;
+
+#[cfg(feature = "aio")]
+pub use dns::AsyncDNSResolver;
diff --git a/redis/src/io/tcp.rs b/redis/src/io/tcp.rs
index 48701e005..ea54b5eb6 100644
--- a/redis/src/io/tcp.rs
+++ b/redis/src/io/tcp.rs
@@ -1,4 +1,9 @@
-use std::{io, net::TcpStream, time::Duration};
+use std::{io, net::TcpStream};
+
+#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
+use std::time::Duration;
+
+pub use socket2;
 
 /// Settings for a TCP stream.
#[derive(Clone, Debug)]
@@ -35,6 +40,7 @@ impl TcpSettings {
     }
 }
 
+#[allow(clippy::derivable_impls)]
 impl Default for TcpSettings {
     fn default() -> Self {
         Self {
diff --git a/redis/src/lib.rs b/redis/src/lib.rs
index b9b8538ee..2ddcd9d8c 100644
--- a/redis/src/lib.rs
+++ b/redis/src/lib.rs
@@ -28,7 +28,7 @@
 //!
 //! The user can enable TLS support using either RusTLS or native support (usually OpenSSL),
 //! using the `tls-rustls` or `tls-native-tls` features respectively. In order to enable TLS
-//! for async usage, the user must enable matching features for their runtime - either `tokio-native-tls-comp``,
+//! for async usage, the user must enable matching features for their runtime - either `tokio-native-tls-comp`,
 //! `tokio-rustls-comp`, `async-std-native-tls-comp`, or `async-std-rustls-comp`. Additionally, the
 //! `tls-rustls-webpki-roots` feature allows usage of webpki-roots for the root certificate store.
 //!
@@ -92,7 +92,8 @@
 //!
 //! * `acl`: enables acl support (enabled by default)
 //! * `tokio-comp`: enables support for async usage with the Tokio runtime (optional)
-//! * `async-std-comp`: enables support for async usage with any runtime which is async-std compliant, such as Smol. (optional)
+//! * `async-std-comp`: enables support for async usage with any runtime which is async-std compliant. (optional)
+//! * `smol-comp`: enables support for async usage with the Smol runtime (optional)
 //! * `geospatial`: enables geospatial support (enabled by default)
 //! * `script`: enables script support (enabled by default)
 //! * `streams`: enables high-level interface for interaction with Redis streams (enabled by default)
@@ -108,6 +109,7 @@
 //! * `sentinel`: enables high-level interfaces for communication with Redis sentinels (optional)
 //! * `json`: enables high-level interfaces for communication with the JSON module (optional)
 //! * `cache-aio`: enables **experimental** client side caching for MultiplexedConnection (optional)
+//! * `disable-client-setinfo`: disables the `CLIENT SETINFO` handshake during connection initialization
 //!
 //! ## Connection Parameters
 //!
@@ -436,7 +438,7 @@ it will not automatically be loaded and retried. The script can be loaded using
 # Async
 
 In addition to the synchronous interface that's been explained above there also exists an
-asynchronous interface based on [`futures`][] and [`tokio`][], or [`async-std`][].
+asynchronous interface based on [`futures`][] and [`tokio`][], [`smol`](https://docs.rs/smol/latest/smol/), or [`async-std`][].
 
 This interface exists under the `aio` (async io) module (which requires that the `aio`
 feature is enabled) and largely mirrors the synchronous API with a few concessions to make it fit the
@@ -462,6 +464,14 @@ let result = redis::cmd("MGET")
 assert_eq!(result, Ok(("foo".to_string(), b"bar".to_vec())));
 # Ok(()) }
 ```
+
+## Runtime support
+The crate supports multiple runtimes, including `tokio`, `async-std`, and `smol`. For Tokio, the crate will
+spawn tasks on the current-thread runtime. For async-std & smol, the crate will spawn tasks on the global runtime.
+It is recommended that the crate be used with support for only a single runtime. If the crate is compiled with multiple runtimes,
+the user should call [`crate::aio::prefer_tokio`], [`crate::aio::prefer_async_std`] or [`crate::aio::prefer_smol`] to set the preferred runtime.
+These functions set global state which automatically chooses the correct runtime for the async connection.
+
 "##
 )]
 //!
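A minimal sketch of the runtime-selection flow described in the doc text above, assuming a build with more than one runtime feature enabled (here `tokio-comp` plus `smol-comp`); the `prefer_*` functions are the ones this patch documents, the rest is the crate's existing API, and the function name is illustrative:

    // Pin the runtime choice once, early in the program, when several *-comp
    // features are compiled in; afterwards async connections are created as usual.
    async fn ping(client: &redis::Client) -> redis::RedisResult<String> {
        // Returns a Result; the test helpers later in this patch simply unwrap it.
        let _ = redis::aio::prefer_smol();
        let mut con = client.get_multiplexed_async_connection().await?;
        redis::cmd("PING").query_async(&mut con).await
    }
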
@@ -533,7 +543,8 @@ pub use crate::client::Client;
 pub use crate::cmd::CommandCacheConfig;
 pub use crate::cmd::{cmd, pack_command, pipe, Arg, Cmd, Iter};
 pub use crate::commands::{
-    Commands, ControlFlow, Direction, LposOptions, PubSubCommands, ScanOptions, SetOptions,
+    Commands, ControlFlow, Direction, FlushAllOptions, FlushDbOptions, HashFieldExpirationOptions,
+    LposOptions, PubSubCommands, ScanOptions, SetOptions,
 };
 pub use crate::connection::{
     parse_redis_url, transaction, Connection, ConnectionAddr, ConnectionInfo, ConnectionLike,
@@ -566,6 +577,7 @@ pub use crate::types::{
     Expiry,
     SetExpiry,
     ExistenceCheck,
+    FieldExistenceCheck,
     ExpireOption,
     Role,
     ReplicaInfo,
@@ -640,6 +652,10 @@ pub mod cluster_routing;
 #[cfg_attr(docsrs, doc(cfg(feature = "r2d2")))]
 mod r2d2;
 
+#[cfg(all(feature = "bb8", feature = "aio"))]
+#[cfg_attr(docsrs, doc(cfg(all(feature = "bb8", feature = "aio"))))]
+mod bb8;
+
 #[cfg(feature = "streams")]
 #[cfg_attr(docsrs, doc(cfg(feature = "streams")))]
 pub mod streams;
diff --git a/redis/src/pipeline.rs b/redis/src/pipeline.rs
index 398aa0ff6..356512f81 100644
--- a/redis/src/pipeline.rs
+++ b/redis/src/pipeline.rs
@@ -82,13 +82,16 @@ impl Pipeline {
         encode_pipeline(&self.commands, self.transaction_mode)
     }
 
-    #[cfg(feature = "aio")]
-    pub(crate) fn write_packed_pipeline(&self, out: &mut Vec<u8>) {
-        write_pipeline(out, &self.commands, self.transaction_mode)
+    /// Returns the number of commands currently queued by the user in the pipeline.
+    ///
+    /// Depending on its configuration (e.g. `atomic`), the pipeline may send more commands to the server than the returned length.
+    pub fn len(&self) -> usize {
+        self.commands.len()
     }
 
-    pub(crate) fn len(&self) -> usize {
-        self.commands.len()
+    /// Returns `true` if the pipeline contains no elements.
+    pub fn is_empty(&self) -> bool {
+        self.commands.is_empty()
     }
 
     fn execute_pipelined(&self, con: &mut dyn ConnectionLike) -> RedisResult<Value> {
diff --git a/redis/src/sentinel.rs b/redis/src/sentinel.rs
index f2e1a8168..ed6d48a1a 100644
--- a/redis/src/sentinel.rs
+++ b/redis/src/sentinel.rs
@@ -125,6 +125,8 @@
 //! ```
 //!
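A short usage sketch for the newly public `Pipeline::len` / `Pipeline::is_empty` from the pipeline.rs hunk above (illustrative, not part of the patch). As the new doc comment warns, `len` counts only the user-queued commands; an atomic pipeline additionally sends MULTI/EXEC over the wire:

    let mut pipe = redis::pipe();
    assert!(pipe.is_empty());
    pipe.atomic().set("key_1", 42).ignore().get("key_1");
    // SET and GET are counted; the implicit MULTI/EXEC are not.
    assert_eq!(pipe.len(), 2);
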
+#[cfg(feature = "aio")]
+use crate::aio::MultiplexedConnection as AsyncConnection;
 #[cfg(feature = "aio")]
 use futures_util::StreamExt;
 use rand::Rng;
@@ -133,11 +135,9 @@ use std::sync::Mutex;
 use std::{collections::HashMap, num::NonZeroUsize};
 
 #[cfg(feature = "aio")]
-use crate::aio::MultiplexedConnection as AsyncConnection;
-
+use crate::aio::MultiplexedConnection;
 #[cfg(feature = "aio")]
 use crate::client::AsyncConnectionConfig;
-
 #[cfg(feature = "tls-rustls")]
 use crate::tls::retrieve_tls_certificates;
 #[cfg(feature = "tls-rustls")]
@@ -145,7 +145,7 @@ use crate::TlsCertificates;
 use crate::{
     connection::ConnectionInfo, types::RedisResult, Client, Cmd, Connection, ConnectionAddr,
     ErrorKind, FromRedisValue, IntoConnectionInfo, ProtocolVersion, RedisConnectionInfo,
-    RedisError, TlsMode, Value,
+    RedisError, Role, TlsMode,
 };
 
 /// The Sentinel type, serves as a special purpose client which builds other clients on
@@ -295,25 +295,57 @@ fn valid_addrs<'a>(
     })
 }
 
-fn check_role_result(result: &RedisResult<Vec<Value>>, target_role: &str) -> bool {
-    if let Ok(values) = result {
-        if !values.is_empty() {
-            if let Ok(role) = String::from_redis_value(&values[0]) {
-                return role.to_ascii_lowercase() == target_role;
-            }
-        }
+fn determine_master_from_role_or_info_replication(
+    connection_info: &ConnectionInfo,
+) -> RedisResult<bool> {
+    let client = Client::open(connection_info.clone())?;
+    let mut conn = client.get_connection()?;
+
+    // Once the client has discovered the address of the master instance, it should attempt a connection with the master, and call the ROLE command in order to verify the role of the instance is actually a master.
+    let role = check_role(&mut conn);
+    if role.is_ok_and(|x| matches!(x, Role::Primary { .. })) {
+        return Ok(true);
+    }
+
+    // If the ROLE command is not available (it was introduced in Redis 2.8.12), a client may resort to the INFO replication command, parsing the role: field of the output.
+    let role = check_info_replication(&mut conn);
+    if role.is_ok_and(|x| x == "master") {
+        return Ok(true);
     }
-    false
+
+    // TODO: Maybe there should be some kind of error message if both role checks fail due to ACL permissions?
+    Ok(false)
 }
 
-fn check_role(connection_info: &ConnectionInfo, target_role: &str) -> bool {
-    if let Ok(client) = Client::open(connection_info.clone()) {
-        if let Ok(mut conn) = client.get_connection() {
-            let result: RedisResult<Vec<Value>> = crate::cmd("ROLE").query(&mut conn);
-            return check_role_result(&result, target_role);
-        }
+fn get_node_role(connection_info: &ConnectionInfo) -> RedisResult<Role> {
+    let client = Client::open(connection_info.clone())?;
+    let mut conn = client.get_connection()?;
+    crate::cmd("ROLE").query(&mut conn)
+}
+
+fn check_role(conn: &mut Connection) -> RedisResult<Role> {
+    crate::cmd("ROLE").query(conn)
+}
+
+fn check_info_replication(conn: &mut Connection) -> RedisResult<String> {
+    let info: String = crate::cmd("INFO").arg("REPLICATION").query(conn)?;
+
+    // Taken from test_sentinel parse_replication_info
+    let info_map = parse_replication_info(info);
+    match info_map.get("role") {
+        Some(x) => Ok(x.clone()),
+        None => Err(RedisError::from((ErrorKind::ParseError, "parse error"))),
     }
-    false
+}
+
+fn parse_replication_info(value: String) -> HashMap<String, String> {
+    let info_map: std::collections::HashMap<String, String> = value
+        .split("\r\n")
+        .filter(|line| !line.trim_start().starts_with('#'))
+        .filter_map(|line| line.split_once(':'))
+        .map(|(key, val)| (key.to_string(), val.to_string()))
+        .collect();
+    info_map
 }
 
 /// Searches for a valid master with the given name in the list of masters returned by
@@ -333,7 +365,7 @@ fn find_valid_master(
         let connection_info = node_connection_info.create_connection_info(ip, port)?;
         #[cfg(feature = "tls-rustls")]
         let connection_info = node_connection_info.create_connection_info(ip, port, certs)?;
-        if check_role(&connection_info, "master") {
+        if determine_master_from_role_or_info_replication(&connection_info).is_ok_and(|x| x) {
             return Ok(connection_info);
         }
     }
@@ -345,14 +377,69 @@ fn find_valid_master(
 }
 
 #[cfg(feature = "aio")]
-async fn async_check_role(connection_info: &ConnectionInfo, target_role: &str) -> bool {
-    if let Ok(client) = Client::open(connection_info.clone()) {
-        if let Ok(mut conn) = client.get_multiplexed_async_connection().await {
-            let result: RedisResult<Vec<Value>> = crate::cmd("ROLE").query_async(&mut conn).await;
-            return check_role_result(&result, target_role);
-        }
+async fn async_determine_master_from_role_or_info_replication(
+    connection_info: &ConnectionInfo,
+) -> RedisResult<bool> {
+    let client = Client::open(connection_info.clone())?;
+    let mut conn = client.get_multiplexed_async_connection().await?;
+
+    // Once the client has discovered the address of the master instance, it should attempt a connection with the master, and call the ROLE command in order to verify the role of the instance is actually a master.
+    let role = async_check_role(&mut conn).await;
+    if role.is_ok_and(|x| matches!(x, Role::Primary { .. })) {
+        return Ok(true);
+    }
+
+    // If the ROLE command is not available (it was introduced in Redis 2.8.12), a client may resort to the INFO replication command, parsing the role: field of the output.
+    let role = async_check_info_replication(&mut conn).await;
+    if role.is_ok_and(|x| x == "master") {
+        return Ok(true);
+    }
+
+    // TODO: Maybe there should be some kind of error message if both role checks fail due to ACL permissions?
+    Ok(false)
+}
+
+#[cfg(feature = "aio")]
+async fn async_determine_slave_from_role_or_info_replication(
+    connection_info: &ConnectionInfo,
+) -> RedisResult<bool> {
+    let client = Client::open(connection_info.clone())?;
+    let mut conn = client.get_multiplexed_async_connection().await?;
+
+    // Once the client has discovered the address of a replica, it should attempt a connection with it, and call the ROLE command in order to verify the role of the instance is actually a replica.
+    let role = async_check_role(&mut conn).await;
+    if role.is_ok_and(|x| matches!(x, Role::Replica { .. })) {
+        return Ok(true);
+    }
+
+    // If the ROLE command is not available (it was introduced in Redis 2.8.12), a client may resort to the INFO replication command, parsing the role: field of the output.
+    let role = async_check_info_replication(&mut conn).await;
+    if role.is_ok_and(|x| x == "slave") {
+        return Ok(true);
+    }
+
+    // TODO: Maybe there should be some kind of error message if both role checks fail due to ACL permissions?
+    Ok(false)
+}
+
+#[cfg(feature = "aio")]
+async fn async_check_role(conn: &mut MultiplexedConnection) -> RedisResult<Role> {
+    let role: RedisResult<Role> = crate::cmd("ROLE").query_async(conn).await;
+    role
+}
+
+#[cfg(feature = "aio")]
+async fn async_check_info_replication(conn: &mut MultiplexedConnection) -> RedisResult<String> {
+    let info: String = crate::cmd("INFO")
+        .arg("REPLICATION")
+        .query_async(conn)
+        .await?;
+    // Taken from test_sentinel parse_replication_info
+    let info_map = parse_replication_info(info);
+    match info_map.get("role") {
+        Some(x) => Ok(x.clone()),
+        None => Err(RedisError::from((ErrorKind::ParseError, "parse error"))),
     }
-    false
 }
 
 /// Async version of [find_valid_master].
@@ -368,7 +455,10 @@ async fn async_find_valid_master(
         let connection_info = node_connection_info.create_connection_info(ip, port)?;
         #[cfg(feature = "tls-rustls")]
         let connection_info = node_connection_info.create_connection_info(ip, port, certs)?;
-        if async_check_role(&connection_info, "master").await {
+        if async_determine_master_from_role_or_info_replication(&connection_info)
+            .await
+            .is_ok_and(|x| x)
+        {
             return Ok(connection_info);
         }
     }
@@ -390,7 +480,9 @@ fn get_valid_replicas_addresses(
 
     Ok(addresses
         .into_iter()
-        .filter(|connection_info| check_role(connection_info, "slave"))
+        .filter(|connection_info| {
+            get_node_role(connection_info).is_ok_and(|x| matches!(x, Role::Replica { ..
})) + }) .collect()) } @@ -416,10 +510,15 @@ async fn async_get_valid_replicas_addresses( node_connection_info: &SentinelNodeConnectionInfo, ) -> RedisResult> { async fn is_replica_role_valid(connection_info: ConnectionInfo) -> Option { - if async_check_role(&connection_info, "slave").await { - Some(connection_info) - } else { - None + match async_determine_slave_from_role_or_info_replication(&connection_info).await { + Ok(x) => { + if x { + Some(connection_info) + } else { + None + } + } + Err(_e) => None, } } @@ -440,10 +539,15 @@ async fn async_get_valid_replicas_addresses( certs: &Option, ) -> RedisResult> { async fn is_replica_role_valid(connection_info: ConnectionInfo) -> Option { - if async_check_role(&connection_info, "slave").await { - Some(connection_info) - } else { - None + match async_determine_slave_from_role_or_info_replication(&connection_info).await { + Ok(x) => { + if x { + Some(connection_info) + } else { + None + } + } + Err(_e) => None, } } @@ -996,7 +1100,7 @@ impl SentinelClient { /// Returns an async connection from the client, using the same logic from /// `SentinelClient::get_connection`. - #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] + #[cfg(feature = "aio")] pub async fn get_async_connection(&mut self) -> RedisResult { self.get_async_connection_with_config(&AsyncConnectionConfig::new()) .await @@ -1004,7 +1108,7 @@ impl SentinelClient { /// Returns an async connection from the client with options, using the same logic from /// `SentinelClient::get_connection`. - #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] + #[cfg(feature = "aio")] pub async fn get_async_connection_with_config( &mut self, config: &AsyncConnectionConfig, @@ -1119,7 +1223,10 @@ impl SentinelClientBuilder { host: _, port: _, ref mut insecure, + #[cfg(feature = "tls-rustls")] ref mut tls_params, + #[cfg(not(feature = "tls-rustls"))] + tls_params: _, } => { if let Some(tls_mode) = self.client_to_sentinel_params.tls_mode { match tls_mode { diff --git a/redis/src/tls.rs b/redis/src/tls.rs index 4adbe2ac5..caf07685a 100644 --- a/redis/src/tls.rs +++ b/redis/src/tls.rs @@ -4,6 +4,7 @@ use rustls::pki_types::pem::PemObject; use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use rustls::RootCertStore; +use crate::connection::TlsConnParams; use crate::{Client, ConnectionAddr, ConnectionInfo, ErrorKind, RedisError, RedisResult}; /// Structure to hold mTLS client _certificate_ and _key_ binaries in PEM format @@ -119,6 +120,8 @@ pub(crate) fn retrieve_tls_certificates( Ok(TlsConnParams { client_tls_params, root_cert_store, + #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))] + danger_accept_invalid_hostnames: false, }) } @@ -143,9 +146,3 @@ impl Clone for ClientTlsParams { } } } - -#[derive(Debug, Clone)] -pub struct TlsConnParams { - pub(crate) client_tls_params: Option, - pub(crate) root_cert_store: Option, -} diff --git a/redis/src/types.rs b/redis/src/types.rs index 1d80dc43d..678bf247c 100644 --- a/redis/src/types.rs +++ b/redis/src/types.rs @@ -68,6 +68,15 @@ pub enum ExistenceCheck { XX, } +/// Helper enum that is used to define field existence checks +#[derive(Clone, Copy)] +pub enum FieldExistenceCheck { + /// FNX -- Only set the fields if all do not already exist. + FNX, + /// FXX -- Only set the fields if all already exist. + FXX, +} + /// Helper enum that is used in some situations to describe /// the behavior of arguments in a numeric context. 
#[derive(PartialEq, Eq, Clone, Debug, Copy)] @@ -690,6 +699,19 @@ impl From for RedisError { } } +#[cfg(feature = "tls-rustls")] +impl From for RedisError { + fn from(err: rustls_native_certs::Error) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescriptionAndDetail( + ErrorKind::IoError, + "Fetch certs Error", + err.to_string(), + ), + } + } +} + #[cfg(feature = "uuid")] impl From for RedisError { fn from(err: uuid::Error) -> RedisError { @@ -2499,9 +2521,7 @@ macro_rules! from_redis_value_for_tuple { Value::Array(ch) => { if let [$($name),*] = &ch[..] { rv.push(($(from_redis_value(&$name)?),*),) - } else { - unreachable!() - }; + }; }, _ => {}, @@ -2535,15 +2555,12 @@ macro_rules! from_redis_value_for_tuple { return Ok(rv) } //It's uglier then before! - for item in items.iter() { + for item in items.iter_mut() { match item { - Value::Array(ch) => { - // TODO - this copies when we could've used the owned value. need to find out how to do this. - if let [$($name),*] = &ch[..] { - rv.push(($(from_redis_value($name)?),*),) - } else { - unreachable!() - }; + Value::Array(ref mut ch) => { + if let [$($name),*] = &mut ch[..] { + rv.push(($(from_owned_redis_value(std::mem::replace($name, Value::Nil))?),*),); + }; }, _ => {}, } diff --git a/redis/tests/parser.rs b/redis/tests/parser.rs index 0ba6350bc..30808c80b 100644 --- a/redis/tests/parser.rs +++ b/redis/tests/parser.rs @@ -12,7 +12,7 @@ use { }; mod support; -use crate::support::{block_on_all, encode_value}; +use crate::support::{current_thread_runtime, encode_value}; #[derive(Clone, Debug)] struct ArbitraryValue(Value); @@ -188,7 +188,7 @@ quickcheck! { let mut partial_reader = PartialAsyncRead { inner: &mut reader, ops: Box::new(seq.into_iter()) }; let mut decoder = combine::stream::Decoder::new(); - let result = block_on_all(redis::parse_redis_value_async(&mut decoder, &mut partial_reader), support::RuntimeType::Tokio); + let result = current_thread_runtime().block_on(redis::parse_redis_value_async(&mut decoder, &mut partial_reader)); assert!(result.as_ref().is_ok(), "{}", result.unwrap_err()); assert_eq!( result.unwrap(), diff --git a/redis/tests/support/cluster.rs b/redis/tests/support/cluster.rs index 86952eb9b..2e45c1c18 100644 --- a/redis/tests/support/cluster.rs +++ b/redis/tests/support/cluster.rs @@ -12,6 +12,8 @@ use redis::aio::ConnectionLike; use redis::cluster_async::Connect; use redis::ConnectionInfo; use redis::ProtocolVersion; +#[cfg(feature = "tls-rustls")] +use redis_test::cluster::ClusterType; use redis_test::cluster::{RedisCluster, RedisClusterConfiguration}; use redis_test::server::{use_protocol, RedisServer}; @@ -28,13 +30,28 @@ pub struct TestClusterContext { impl TestClusterContext { pub fn new() -> TestClusterContext { - Self::new_with_config(RedisClusterConfiguration::default()) + Self::new_with_config(RedisClusterConfiguration { + tls_insecure: false, + ..Default::default() + }) } pub fn new_with_mtls() -> TestClusterContext { Self::new_with_config_and_builder( RedisClusterConfiguration { mtls_enabled: true, + tls_insecure: false, + ..Default::default() + }, + identity, + ) + } + + pub fn new_without_ip_alts() -> TestClusterContext { + Self::new_with_config_and_builder( + RedisClusterConfiguration { + tls_insecure: false, + certs_with_ip_alts: false, ..Default::default() }, identity, @@ -46,6 +63,19 @@ impl TestClusterContext { } pub fn new_with_cluster_client_builder(initializer: F) -> TestClusterContext + where + F: FnOnce(redis::cluster::ClusterClientBuilder) -> redis::cluster::ClusterClientBuilder, + 
{ + Self::new_with_config_and_builder( + RedisClusterConfiguration { + tls_insecure: false, + ..Default::default() + }, + initializer, + ) + } + + pub fn new_insecure_with_cluster_client_builder(initializer: F) -> TestClusterContext where F: FnOnce(redis::cluster::ClusterClientBuilder) -> redis::cluster::ClusterClientBuilder, { @@ -60,6 +90,8 @@ impl TestClusterContext { F: FnOnce(redis::cluster::ClusterClientBuilder) -> redis::cluster::ClusterClientBuilder, { start_tls_crypto_provider(); + #[cfg(feature = "tls-rustls")] + let tls_insecure = cluster_config.tls_insecure; let mtls_enabled = cluster_config.mtls_enabled; let cluster = RedisCluster::new(cluster_config); let initial_nodes: Vec = cluster @@ -70,7 +102,7 @@ impl TestClusterContext { .use_protocol(use_protocol()); #[cfg(feature = "tls-rustls")] - if mtls_enabled { + if mtls_enabled || (ClusterType::get_intended() == ClusterType::TcpTls && !tls_insecure) { if let Some(tls_file_paths) = &cluster.tls_paths { builder = builder.certs(load_certs_from_file(tls_file_paths)); } diff --git a/redis/tests/support/mock_cluster.rs b/redis/tests/support/mock_cluster.rs index c1ec2b957..93bcb2cb9 100644 --- a/redis/tests/support/mock_cluster.rs +++ b/redis/tests/support/mock_cluster.rs @@ -35,10 +35,9 @@ pub struct MockConnection { #[cfg(feature = "cluster-async")] impl cluster_async::Connect for MockConnection { - fn connect<'a, T>( + fn connect_with_config<'a, T>( info: T, - _response_timeout: Duration, - _connection_timeout: Duration, + _config: redis::AsyncConnectionConfig, ) -> RedisFuture<'a, Self> where T: IntoConnectionInfo + Send + 'a, diff --git a/redis/tests/support/mod.rs b/redis/tests/support/mod.rs index fb46308de..90413aebe 100644 --- a/redis/tests/support/mod.rs +++ b/redis/tests/support/mod.rs @@ -4,7 +4,7 @@ use futures::Future; #[cfg(feature = "aio")] use redis::{aio, cmd}; -use redis::{ConnectionAddr, InfoDict, Pipeline, ProtocolVersion, RedisResult, Value}; +use redis::{Commands, ConnectionAddr, InfoDict, Pipeline, ProtocolVersion, RedisResult, Value}; use redis_test::server::{use_protocol, Module, RedisServer}; use redis_test::utils::{get_random_available_port, TlsFilePaths}; #[cfg(feature = "tls-rustls")] @@ -32,9 +32,12 @@ pub fn current_thread_runtime() -> tokio::runtime::Runtime { #[cfg(feature = "aio")] pub enum RuntimeType { + #[cfg(feature = "tokio-comp")] Tokio, #[cfg(feature = "async-std-comp")] AsyncStd, + #[cfg(feature = "smol-comp")] + Smol, } #[cfg(feature = "aio")] @@ -67,14 +70,17 @@ where let f = futures_util::FutureExt::fuse(f); futures::pin_mut!(f, check_future); + let f = async move { + futures::select! {res = f => res, err = check_future => err} + }; + let res = match runtime { - RuntimeType::Tokio => current_thread_runtime().block_on(async { - futures::select! {res = f => res, err = check_future => err} - }), + #[cfg(feature = "tokio-comp")] + RuntimeType::Tokio => block_on_all_using_tokio(f), #[cfg(feature = "async-std-comp")] - RuntimeType::AsyncStd => block_on_all_using_async_std(async move { - futures::select! 
{res = f => res, err = check_future => err} - }), + RuntimeType::AsyncStd => block_on_all_using_async_std(f), + #[cfg(feature = "smol-comp")] + RuntimeType::Smol => block_on_all_using_smol(f), }; let _ = panic::take_hook(); @@ -87,8 +93,9 @@ where #[cfg(feature = "aio")] #[rstest::rstest] -#[case::tokio(RuntimeType::Tokio)] +#[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] +#[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] #[should_panic(expected = "Internal thread panicked")] fn test_block_on_all_panics_from_spawns(#[case] runtime: RuntimeType) { let _ = block_on_all( @@ -104,14 +111,36 @@ fn test_block_on_all_panics_from_spawns(#[case] runtime: RuntimeType) { ); } +#[cfg(feature = "tokio-comp")] +fn block_on_all_using_tokio(f: F) -> F::Output +where + F: Future, +{ + #[cfg(any(feature = "async-std-comp", feature = "smol-comp"))] + redis::aio::prefer_tokio().unwrap(); + current_thread_runtime().block_on(f) +} + #[cfg(feature = "async-std-comp")] fn block_on_all_using_async_std(f: F) -> F::Output where F: Future, { + #[cfg(any(feature = "tokio-comp", feature = "smol-comp"))] + redis::aio::prefer_async_std().unwrap(); async_std::task::block_on(f) } +#[cfg(feature = "smol-comp")] +fn block_on_all_using_smol(f: F) -> F::Output +where + F: Future, +{ + #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] + redis::aio::prefer_smol().unwrap(); + smol::block_on(f) +} + #[cfg(any(feature = "cluster", feature = "cluster-async"))] mod cluster; @@ -232,7 +261,7 @@ impl TestContext { } } } - redis::cmd("FLUSHDB").exec(&mut con).unwrap(); + con.flushdb::<()>().unwrap(); TestContext { server, @@ -245,12 +274,6 @@ impl TestContext { self.client.get_connection().unwrap() } - #[cfg(feature = "aio")] - #[allow(deprecated)] - pub async fn deprecated_async_connection(&self) -> RedisResult { - self.client.get_async_connection().await - } - #[cfg(feature = "aio")] pub async fn async_connection(&self) -> RedisResult { self.client.get_multiplexed_async_connection().await diff --git a/redis/tests/test_async.rs b/redis/tests/test_async.rs index 53fd3698e..a1cb2f200 100644 --- a/redis/tests/test_async.rs +++ b/redis/tests/test_async.rs @@ -137,8 +137,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_args(#[case] runtime: RuntimeType) { test_with_all_connection_types( |mut con| async move { @@ -163,7 +164,7 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] fn test_works_with_paused_time(#[case] runtime: RuntimeType) { test_with_all_connection_types_with_setup( || tokio::time::pause(), @@ -193,8 +194,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_can_authenticate_with_username_and_password(#[case] runtime: RuntimeType) { let ctx = TestContext::new(); block_on_all( @@ -242,8 +244,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", 
case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_nice_hash_api(#[case] runtime: RuntimeType) { test_with_all_connection_types( |mut connection| async move { @@ -267,8 +270,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_nice_hash_api_in_pipe(#[case] runtime: RuntimeType) { test_with_all_connection_types( |mut connection| async move { @@ -298,8 +302,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn dont_panic_on_closed_multiplexed_connection(#[case] runtime: RuntimeType) { let ctx = TestContext::new(); let client = ctx.client.clone(); @@ -346,8 +351,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_pipeline_transaction(#[case] runtime: RuntimeType) { test_with_all_connection_types( |mut con| async move { @@ -375,8 +381,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_client_tracking_doesnt_block_execution(#[case] runtime: RuntimeType) { //It checks if the library distinguish a push-type message from the others and continues its normal operation. 
test_with_all_connection_types( @@ -403,8 +410,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_pipeline_transaction_with_errors(#[case] runtime: RuntimeType) { test_with_all_connection_types( |mut con| async move { @@ -469,8 +477,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_pipe_over_multiplexed_connection(#[case] runtime: RuntimeType) { test_with_all_connection_types( |mut con| async move { @@ -488,8 +497,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_running_multiple_commands(#[case] runtime: RuntimeType) { test_with_all_connection_types( |con| async move { @@ -506,8 +516,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_transaction_multiplexed_connection(#[case] runtime: RuntimeType) { test_with_all_connection_types( |con| async move { @@ -549,8 +560,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_async_scanning(#[values(2, 1000)] batch_size: usize, #[case] runtime: RuntimeType) { test_with_all_connection_types( |mut con| async move { @@ -586,8 +598,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_async_scanning_iterative( #[values(2, 1000)] batch_size: usize, #[case] runtime: RuntimeType, @@ -630,8 +643,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_async_scanning_stream( #[values(2, 1000)] batch_size: usize, #[case] runtime: RuntimeType, @@ -675,8 +689,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_response_timeout_multiplexed_connection(#[case] runtime: RuntimeType) { let ctx = TestContext::new(); block_on_all( @@ -696,8 +711,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + 
#[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] #[cfg(feature = "script")] fn test_script(#[case] runtime: RuntimeType) { test_with_all_connection_types( @@ -732,8 +748,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] #[cfg(feature = "script")] fn test_script_load(#[case] runtime: RuntimeType) { test_with_all_connection_types( @@ -749,8 +766,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] #[cfg(feature = "script")] fn test_script_returning_complex_type(#[case] runtime: RuntimeType) { test_with_all_connection_types( @@ -773,8 +791,9 @@ mod basic_async { // Allowing `let ()` as `query_async` requires the type it converts the result to. #[allow(clippy::let_unit_value, clippy::iter_nth_zero)] #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn io_error_on_kill_issue_320(#[case] runtime: RuntimeType) { let ctx = TestContext::new(); block_on_all( @@ -801,8 +820,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn invalid_password_issue_343(#[case] runtime: RuntimeType) { let ctx = TestContext::new(); block_on_all( @@ -834,8 +854,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_scan_with_options_works(#[case] runtime: RuntimeType) { test_with_all_connection_types( |mut con| async move { @@ -871,8 +892,9 @@ mod basic_async { // Test issue of Stream trait blocking if we try to iterate more than 10 items // https://github.com/mitsuhiko/redis-rs/issues/537 and https://github.com/mitsuhiko/redis-rs/issues/583 #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_issue_stream_blocks(#[case] runtime: RuntimeType) { test_with_all_connection_types( |mut con| async move { @@ -896,8 +918,9 @@ mod basic_async { #[rstest] // Test issue of AsyncCommands::scan returning the wrong number of keys // https://github.com/redis-rs/redis-rs/issues/759 - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = 
"smol-comp", case::smol(RuntimeType::Smol))] fn test_issue_async_commands_scan_broken(#[case] runtime: RuntimeType) { test_with_all_connection_types( |mut con| async move { @@ -924,8 +947,9 @@ mod basic_async { use super::*; #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn pub_sub_subscription(#[case] runtime: RuntimeType) { let ctx = TestContext::new(); block_on_all( @@ -956,8 +980,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn pub_sub_subscription_to_multiple_channels(#[case] runtime: RuntimeType) { use redis::RedisError; @@ -985,8 +1010,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn pub_sub_unsubscription(#[case] runtime: RuntimeType) { const SUBSCRIPTION_KEY: &str = "phonewave-pub-sub-unsubscription"; @@ -1014,8 +1040,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn can_receive_messages_while_sending_requests_from_split_pub_sub( #[case] runtime: RuntimeType, ) { @@ -1045,8 +1072,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn can_send_ping_on_split_pubsub(#[case] runtime: RuntimeType) { let ctx = TestContext::new(); block_on_all( @@ -1097,8 +1125,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn can_receive_messages_from_split_pub_sub_after_sink_was_dropped( #[case] runtime: RuntimeType, ) { @@ -1129,8 +1158,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn can_receive_messages_from_split_pub_sub_after_into_on_message( #[case] runtime: RuntimeType, ) { @@ -1163,8 +1193,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn cannot_subscribe_on_split_pub_sub_after_stream_was_dropped( #[case] runtime: RuntimeType, ) { @@ -1184,8 +1215,9 @@ mod basic_async { } #[rstest] - 
#[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn automatic_unsubscription(#[case] runtime: RuntimeType) { const SUBSCRIPTION_KEY: &str = "phonewave-automatic-unsubscription"; @@ -1222,8 +1254,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn automatic_unsubscription_on_split(#[case] runtime: RuntimeType) { const SUBSCRIPTION_KEY: &str = "phonewave-automatic-unsubscription-on-split"; @@ -1274,38 +1307,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] - #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] - fn pub_sub_conn_reuse(#[case] runtime: RuntimeType) { - let ctx = TestContext::new(); - block_on_all( - async move { - #[allow(deprecated)] - let mut pubsub_conn = ctx.deprecated_async_connection().await?.into_pubsub(); - pubsub_conn.subscribe("phonewave").await?; - pubsub_conn.psubscribe("*").await?; - - #[allow(deprecated)] - let mut conn = pubsub_conn.into_connection().await; - redis::cmd("SET") - .arg("foo") - .arg("bar") - .exec_async(&mut conn) - .await?; - - let res: String = redis::cmd("GET").arg("foo").query_async(&mut conn).await?; - assert_eq!(&res, "bar"); - - Ok::<_, RedisError>(()) - }, - runtime, - ) - .unwrap(); - } - - #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn pipe_errors_do_not_affect_subsequent_commands(#[case] runtime: RuntimeType) { test_with_all_connection_types( |mut conn| async move { @@ -1328,8 +1332,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn multiplexed_pub_sub_subscribe_on_multiple_channels(#[case] runtime: RuntimeType) { let ctx = TestContext::new(); if ctx.protocol == ProtocolVersion::RESP2 { @@ -1366,8 +1371,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn non_transaction_errors_do_not_affect_other_results_in_pipeline( #[case] runtime: RuntimeType, ) { @@ -1396,8 +1402,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn pub_sub_multiple(#[case] runtime: RuntimeType) { use redis::RedisError; @@ -1471,8 +1478,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", 
case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn pub_sub_requires_resp3(#[case] runtime: RuntimeType) { if use_protocol() != ProtocolVersion::RESP2 { return; @@ -1493,8 +1501,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn push_sender_send_on_disconnect(#[case] runtime: RuntimeType) { use redis::RedisError; @@ -1526,7 +1535,7 @@ mod basic_async { #[cfg(feature = "connection-manager")] #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[case::async_std(RuntimeType::AsyncStd)] fn manager_should_resubscribe_to_pubsub_channels_after_disconnect( #[case] runtime: RuntimeType, @@ -1657,8 +1666,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_async_basic_pipe_with_parsing_error(#[case] runtime: RuntimeType) { // Tests a specific case involving repeated errors in transactions. test_with_all_connection_types( @@ -1694,8 +1704,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] #[cfg(feature = "connection-manager")] fn test_connection_manager_reconnect_after_delay(#[case] runtime: RuntimeType) { let max_delay_between_attempts = 2; @@ -1754,7 +1765,7 @@ mod basic_async { #[cfg(feature = "connection-manager")] #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[case::async_std(RuntimeType::AsyncStd)] fn manager_should_reconnect_without_actions_if_push_sender_is_set( #[case] runtime: RuntimeType, @@ -1795,8 +1806,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_multiplexed_connection_kills_connection_on_drop_even_when_blocking( #[case] runtime: RuntimeType, ) { @@ -1844,13 +1856,42 @@ mod basic_async { .unwrap(); } + #[rstest] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] + #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] + fn test_monitor(#[case] runtime: RuntimeType) { + let ctx = TestContext::new(); + block_on_all( + async move { + let mut conn = ctx.async_connection().await.unwrap(); + let monitor_conn = ctx.client.get_async_monitor().await.unwrap(); + let mut stream = monitor_conn.into_on_message(); + + let _: () = conn.set("foo", "bar").await?; + + let msg: String = stream.next().await.unwrap(); + assert!(msg.ends_with("\"SET\" \"foo\" \"bar\"")); + + drop(ctx); + + assert!(stream.next().await.is_none()); + + Ok(()) + }, + runtime, + ) + .unwrap(); + } + #[cfg(feature = "tls-rustls")] mod mtls_test { use super::*; 
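For readers following the mTLS tests in this module, a hedged sketch of how a client with client-certificate authentication is typically constructed under the `tls-rustls` feature. `Client::build_with_tls` and `TlsCertificates` come from this crate; the function name and the PEM buffers are placeholders:

    fn open_mtls_client(
        url: &str,
        client_cert_pem: Vec<u8>,
        client_key_pem: Vec<u8>,
        root_cert_pem: Vec<u8>,
    ) -> redis::RedisResult<redis::Client> {
        redis::Client::build_with_tls(
            url,
            redis::TlsCertificates {
                client_tls: Some(redis::ClientTlsConfig {
                    client_cert: client_cert_pem,
                    client_key: client_key_pem,
                }),
                root_cert: Some(root_cert_pem),
            },
        )
    }
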
#[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_should_connect_mtls(#[case] runtime: RuntimeType) { let ctx = TestContext::new_with_mtls(); @@ -1875,8 +1916,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_should_not_connect_if_tls_active(#[case] runtime: RuntimeType) { let ctx = TestContext::new_with_mtls(); @@ -1918,8 +1960,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] #[cfg(feature = "connection-manager")] fn test_resp3_pushes_connection_manager(#[case] runtime: RuntimeType) { let ctx = TestContext::new(); @@ -1960,8 +2003,9 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_select_db(#[case] runtime: RuntimeType) { let ctx = TestContext::new(); let mut connection_info = ctx.client.get_connection_info().clone(); @@ -1986,7 +2030,7 @@ mod basic_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] fn test_multiplexed_connection_send_single_disconnect_on_connection_failure( #[case] runtime: RuntimeType, diff --git a/redis/tests/test_basic.rs b/redis/tests/test_basic.rs index acec3f400..a3105302e 100644 --- a/redis/tests/test_basic.rs +++ b/redis/tests/test_basic.rs @@ -2,14 +2,33 @@ mod support; +macro_rules! 
run_test_if_version_supported {
+    ($minimum_required_version:expr) => {{
+        let ctx = TestContext::new();
+        let redis_version = ctx.get_version();
+
+        if redis_version < *$minimum_required_version {
+            eprintln!("Skipping the test because the current version of Redis {:?} is lower than the minimum required version {:?}.",
+            redis_version, $minimum_required_version);
+            return;
+        }
+
+        ctx
+    }};
+}
+
 #[cfg(test)]
 mod basic {
     use assert_approx_eq::assert_approx_eq;
-    use redis::{cmd, Client, ProtocolVersion, PushInfo, RedisConnectionInfo, Role, ScanOptions};
+    use rand::distr::Alphanumeric;
+    use rand::{rng, Rng};
+    use redis::{
+        cmd, Client, Connection, ProtocolVersion, PushInfo, RedisConnectionInfo, Role, ScanOptions,
+    };
     use redis::{
         Commands, ConnectionInfo, ConnectionLike, ControlFlow, ErrorKind, ExistenceCheck,
-        ExpireOption, Expiry, PubSubCommands, PushKind, RedisResult, SetExpiry, SetOptions,
-        ToRedisArgs, Value,
+        ExpireOption, Expiry, FieldExistenceCheck, HashFieldExpirationOptions, PubSubCommands,
+        PushKind, RedisResult, SetExpiry, SetOptions, ToRedisArgs, Value,
     };
     use redis_test::utils::get_listener_on_free_port;
     use std::collections::{BTreeMap, BTreeSet};
@@ -21,6 +40,37 @@ mod basic {
 
     use crate::{assert_args, support::*};
 
+    const REDIS_VERSION_CE_8_0_RC1: (u16, u16, u16) = (7, 9, 240);
+    const HASH_KEY: &str = "testing_hash";
+    const HASH_FIELDS_AND_VALUES: [(&str, u8); 5] =
+        [("f1", 1), ("f2", 2), ("f3", 4), ("f4", 8), ("f5", 16)];
+    const FIELD_EXISTS_WITHOUT_TTL: i8 = -1;
+
+    /// Generates a unique hash key that does not already exist.
+    fn generate_random_testing_hash_key(con: &mut Connection) -> String {
+        const TEST_HASH_KEY_BASE: &str = "testing_hash";
+        const TEST_HASH_KEY_RANDOM_LENGTH: usize = 7;
+
+        loop {
+            let generated_hash_key = format!(
+                "{}_{}",
+                TEST_HASH_KEY_BASE,
+                rng()
+                    .sample_iter(&Alphanumeric)
+                    .take(TEST_HASH_KEY_RANDOM_LENGTH)
+                    .map(char::from)
+                    .collect::<String>()
+            );
+
+            let hash_exists: bool = con.exists(&generated_hash_key).unwrap();
+
+            if !hash_exists {
+                println!("Generated random testing hash key: {}", &generated_hash_key);
+                return generated_hash_key;
+            }
+        }
+    }
+
     #[test]
     fn test_parse_redis_url() {
         let redis_url = "redis://127.0.0.1:1234/0".to_string();
@@ -404,6 +454,583 @@ mod basic {
         assert_eq!(con.unlink(&["foo"]), Ok(1));
     }
 
+    /// Verify that the hash contains exactly the specified fields with their corresponding values.
+    fn verify_exact_hash_fields_and_values(
+        con: &mut Connection,
+        hash_key: &str,
+        hash_fields_and_values: &[(&str, u8)],
+    ) {
+        let hash_fields: HashMap<String, u8> = con.hgetall(hash_key).unwrap();
+        assert_eq!(hash_fields.len(), hash_fields_and_values.len());
+
+        for (field, value) in hash_fields_and_values {
+            assert_eq!(hash_fields.get(*field), Some(value));
+        }
+    }
+
+    #[inline(always)]
+    fn verify_fields_absence_from_hash(
+        hash_fields: &HashMap<String, u8>,
+        hash_fields_to_check: &[&str],
+    ) {
+        hash_fields_to_check.iter().for_each(|key| {
+            assert!(!hash_fields.contains_key(*key));
+        });
+    }
+
+    /// The test validates the following scenarios for the HGETDEL command:
+    ///
+    /// 1. It successfully deletes a single field from a given existing hash.
+    /// 2. Attempting to delete a non-existing field from a given existing hash results in a NIL response.
+    /// 3. It successfully deletes multiple fields from a given existing hash.
+    /// 4. When used on a hash with only one field, it deletes the entire hash.
+    /// 5. Attempting to delete a field from a non-existing hash results in a NIL response.
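Before the test itself, a condensed sketch of the HGETDEL flow it exercises (illustrative only; `hget_del` is the command method added by this patch and needs a server that supports HGETDEL, i.e. Redis 8.0+):

    use redis::{Commands, Value};
    let _: () = con.hset_multiple("h", &[("f1", 1), ("f2", 2)])?;
    // HGETDEL returns the deleted fields' values...
    let vals: Vec<Value> = con.hget_del("h", &["f1", "f2"])?;
    assert_eq!(vals.len(), 2);
    // ...and removing the last remaining field deletes the hash itself.
    let exists: bool = con.exists("h")?;
    assert!(!exists);
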
+    #[test]
+    fn test_hget_del() {
+        let ctx = run_test_if_version_supported!(&REDIS_VERSION_CE_8_0_RC1);
+        let mut con = ctx.connection();
+        // Create a hash with multiple fields and values that will be used for testing
+        assert_eq!(con.hset_multiple(HASH_KEY, &HASH_FIELDS_AND_VALUES), Ok(()));
+
+        // Delete the first field
+        assert_eq!(
+            con.hget_del(HASH_KEY, HASH_FIELDS_AND_VALUES[0].0),
+            Ok([HASH_FIELDS_AND_VALUES[0].1])
+        );
+
+        let mut removed_fields = Vec::from([HASH_FIELDS_AND_VALUES[0].0]);
+
+        // Verify that the field has been deleted
+        let remaining_hash_fields: HashMap<String, u8> = con.hgetall(HASH_KEY).unwrap();
+        assert_eq!(
+            remaining_hash_fields.len(),
+            HASH_FIELDS_AND_VALUES.len() - removed_fields.len()
+        );
+        verify_fields_absence_from_hash(&remaining_hash_fields, &removed_fields);
+
+        // Verify that a non-existing field returns NIL by attempting to delete the same field again
+        assert_eq!(con.hget_del(HASH_KEY, &removed_fields), Ok([Value::Nil]));
+
+        // Prepare additional fields for deletion
+        let fields_to_delete = [
+            HASH_FIELDS_AND_VALUES[1].0,
+            HASH_FIELDS_AND_VALUES[2].0,
+            HASH_FIELDS_AND_VALUES[3].0,
+        ];
+
+        // Delete the additional fields
+        assert_eq!(
+            con.hget_del(HASH_KEY, &fields_to_delete),
+            Ok([
+                HASH_FIELDS_AND_VALUES[1].1,
+                HASH_FIELDS_AND_VALUES[2].1,
+                HASH_FIELDS_AND_VALUES[3].1
+            ])
+        );
+
+        removed_fields.extend_from_slice(&fields_to_delete);
+
+        // Verify that all of the fields have been deleted
+        let remaining_hash_fields: HashMap<String, u8> = con.hgetall(HASH_KEY).unwrap();
+        assert_eq!(
+            remaining_hash_fields.len(),
+            HASH_FIELDS_AND_VALUES.len() - removed_fields.len()
+        );
+        verify_fields_absence_from_hash(&remaining_hash_fields, &removed_fields);
+
+        // Verify that removing the last field deletes the hash
+        assert_eq!(
+            con.hget_del(HASH_KEY, HASH_FIELDS_AND_VALUES[4].0),
+            Ok([HASH_FIELDS_AND_VALUES[4].1])
+        );
+        assert_eq!(con.exists(HASH_KEY), Ok(false));
+
+        // Verify that HGETDEL on a non-existing hash returns NIL
+        assert_eq!(
+            con.hget_del(HASH_KEY, HASH_FIELDS_AND_VALUES[4].0),
+            Ok([Value::Nil])
+        );
+    }
+
+    /// The test validates the following scenarios for the HGETEX command:
+    ///
+    /// 1. It successfully retrieves a single field from a given existing hash without setting its expiration.
+    /// 2. It successfully retrieves multiple fields from a given existing hash without setting their expiration.
+    /// 3. It successfully retrieves a single field from a given existing hash and sets its expiration to 1 second.
+    ///    It verifies that the field has been set to expire and that it is no longer present in the hash after it expires.
+    /// 4. Attempting to retrieve a non-existing field from a given existing hash results in a NIL response.
+    /// 5. It successfully retrieves multiple fields from a given existing hash and sets their expiration to 1 second.
+    ///    It verifies that the fields have been set to expire and that they are no longer present in the hash after they expire.
+    /// 6. Attempting to retrieve a field from a non-existing hash results in a NIL response.
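As above, a condensed sketch of the HGETEX behaviour the test below walks through (illustrative only; `hget_ex` and `httl` are the command methods used by these tests, and `Expiry::PERSIST` is the variant the tests use when no new TTL should be attached):

    use redis::{Commands, Expiry};
    let _: () = con.hset_multiple("h", &[("f1", 1), ("f2", 2)])?;
    // Read a field, leaving it without a TTL...
    let v: Vec<u8> = con.hget_ex("h", "f1", Expiry::PERSIST)?;
    assert_eq!(v, [1]);
    // ...or read it and attach a one-second TTL in the same round trip.
    let v: Vec<u8> = con.hget_ex("h", "f2", Expiry::EX(1))?;
    let ttls: Vec<i64> = con.httl("h", "f2")?;
    assert_eq!(ttls, [1]);
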
+ #[test] + fn test_hget_ex() { + let ctx = run_test_if_version_supported!(&REDIS_VERSION_CE_8_0_RC1); + let mut con = ctx.connection(); + // Create a hash with multiple fields and values that will be used for testing + assert_eq!(con.hset_multiple(HASH_KEY, &HASH_FIELDS_AND_VALUES), Ok(())); + + // Scenario 1 + // Retrieve a single field without setting its expiration + assert_eq!( + con.hget_ex(HASH_KEY, HASH_FIELDS_AND_VALUES[0].0, Expiry::PERSIST), + Ok([HASH_FIELDS_AND_VALUES[0].1]) + ); + assert_eq!( + con.httl(HASH_KEY, HASH_FIELDS_AND_VALUES[0].0), + Ok([FIELD_EXISTS_WITHOUT_TTL]) + ); + + // Scenario 2 + // Retrieve multiple fields at once without setting their expiration + let fields_to_retrieve = [HASH_FIELDS_AND_VALUES[1].0, HASH_FIELDS_AND_VALUES[2].0]; + assert_eq!( + con.hget_ex(HASH_KEY, &fields_to_retrieve, Expiry::PERSIST), + Ok([HASH_FIELDS_AND_VALUES[1].1, HASH_FIELDS_AND_VALUES[2].1]) + ); + assert_eq!( + con.httl(HASH_KEY, &fields_to_retrieve), + Ok([FIELD_EXISTS_WITHOUT_TTL, FIELD_EXISTS_WITHOUT_TTL]) + ); + + // Scenario 3 + // Retrieve a single field and set its expiration to 1 second + assert_eq!( + con.hget_ex(HASH_KEY, HASH_FIELDS_AND_VALUES[0].0, Expiry::EX(1)), + Ok([HASH_FIELDS_AND_VALUES[0].1]) + ); + // Verify that all fields are still present in the hash + verify_exact_hash_fields_and_values(&mut con, HASH_KEY, &HASH_FIELDS_AND_VALUES); + // Verify that the field has been set to expire + assert_eq!(con.httl(HASH_KEY, HASH_FIELDS_AND_VALUES[0].0), Ok([1])); + // Wait for the field to expire and verify it + sleep(Duration::from_millis(1100)); + + let mut expired_fields = Vec::from([HASH_FIELDS_AND_VALUES[0].0]); + + let remaining_hash_fields: HashMap<String, u8> = con.hgetall(HASH_KEY).unwrap(); + assert_eq!( + remaining_hash_fields.len(), + HASH_FIELDS_AND_VALUES.len() - expired_fields.len() + ); + verify_fields_absence_from_hash(&remaining_hash_fields, &expired_fields); + + // Scenario 4 + // Verify that a non-existing field returns NIL by attempting to retrieve it with HGETEX + assert_eq!( + con.hget_ex(HASH_KEY, &expired_fields, Expiry::PERSIST), + Ok([Value::Nil]) + ); + + // Scenario 5 + // Retrieve multiple fields and set their expiration to 1 second + let fields_to_expire = [ + HASH_FIELDS_AND_VALUES[1].0, + HASH_FIELDS_AND_VALUES[2].0, + HASH_FIELDS_AND_VALUES[3].0, + HASH_FIELDS_AND_VALUES[4].0, + ]; + let hash_field_values: Vec<u8> = con + .hget_ex(HASH_KEY, &fields_to_expire, Expiry::EX(1)) + .unwrap(); + assert_eq!(hash_field_values.len(), fields_to_expire.len()); + + for i in 0..fields_to_expire.len() { + assert_eq!(hash_field_values[i], HASH_FIELDS_AND_VALUES[i + 1].1); + } + // Verify that all fields, except the first one, which has already expired, are still present in the hash + verify_exact_hash_fields_and_values(&mut con, HASH_KEY, &HASH_FIELDS_AND_VALUES[1..]); + // Verify that the fields have been set to expire + assert_eq!( + con.httl(HASH_KEY, &fields_to_expire), + Ok(vec![1; fields_to_expire.len()]) + ); + // Wait for the fields to expire and verify it + sleep(Duration::from_millis(1100)); + + expired_fields.extend_from_slice(&fields_to_expire); + + let remaining_hash_fields: HashMap<String, u8> = con.hgetall(HASH_KEY).unwrap(); + assert_eq!( + remaining_hash_fields.len(), + HASH_FIELDS_AND_VALUES.len() - expired_fields.len() + ); + verify_fields_absence_from_hash(&remaining_hash_fields, &expired_fields); + + // Scenario 6 + // Verify that HGETEX on a non-existing hash returns NIL + assert_eq!(con.exists(HASH_KEY), Ok(false)); + assert_eq!( +
con.hget_ex(HASH_KEY, &expired_fields, Expiry::PERSIST), + Ok(vec![Value::Nil; expired_fields.len()]) + ); + } + + /// The test validates the various expiration options for hash fields using the HGETEX command. + /// + /// It tests setting expiration using the EX, PX, EXAT, and PXAT options, + /// as well as removing an existing expiration using the PERSIST option. + #[test] + fn test_hget_ex_field_expiration_options() { + let ctx = run_test_if_version_supported!(&REDIS_VERSION_CE_8_0_RC1); + let mut con = ctx.connection(); + // Create a hash with multiple fields and values that will be used for testing + assert_eq!(con.hset_multiple(HASH_KEY, &HASH_FIELDS_AND_VALUES), Ok(())); + + // Verify that initially all fields are present in the hash + verify_exact_hash_fields_and_values(&mut con, HASH_KEY, &HASH_FIELDS_AND_VALUES); + + // Set the fields to expire in 1 second using different expiration options + assert_eq!( + con.hget_ex(HASH_KEY, HASH_FIELDS_AND_VALUES[0].0, Expiry::EX(1)), + Ok([HASH_FIELDS_AND_VALUES[0].1]) + ); + assert_eq!( + con.hget_ex(HASH_KEY, HASH_FIELDS_AND_VALUES[1].0, Expiry::PX(1000)), + Ok([HASH_FIELDS_AND_VALUES[1].1]) + ); + let current_timestamp = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); + assert_eq!( + con.hget_ex( + HASH_KEY, + HASH_FIELDS_AND_VALUES[2].0, + Expiry::EXAT(current_timestamp.as_secs() + 1) + ), + Ok([HASH_FIELDS_AND_VALUES[2].1]) + ); + assert_eq!( + con.hget_ex( + HASH_KEY, + HASH_FIELDS_AND_VALUES[3].0, + Expiry::PXAT(current_timestamp.as_millis() as u64 + 1000) + ), + Ok([HASH_FIELDS_AND_VALUES[3].1]) + ); + assert_eq!( + con.hget_ex(HASH_KEY, HASH_FIELDS_AND_VALUES[4].0, Expiry::EX(1)), + Ok([HASH_FIELDS_AND_VALUES[4].1]) + ); + // Remove the expiration from the last field + assert_eq!( + con.hget_ex(HASH_KEY, HASH_FIELDS_AND_VALUES[4].0, Expiry::PERSIST), + Ok([HASH_FIELDS_AND_VALUES[4].1]) + ); + + // Wait for the fields to expire and verify that only the last field remains in the hash + sleep(Duration::from_millis(1100)); + verify_exact_hash_fields_and_values(&mut con, HASH_KEY, &[HASH_FIELDS_AND_VALUES[4]]); + + // Remove the hash + assert_eq!(con.del(HASH_KEY), Ok(1)); + } + + /// The test validates the following scenarios for the HSETEX command: + /// + /// Tests the behavior of HSETEX with different field existence checks (FNX and FXX): + /// 1. (FNX) successfully sets fields in a hash that does not exist and creates the hash. + /// 2. (FNX) fails to set fields in a hash that already contains one or more of the fields. + /// 3. (FXX) fails to set fields in a hash that does not have one or more of the fields. + /// + /// Tests the behavior of HSETEX with and without expiration: + /// + /// Note: All of the following tests use FXX and operate on existing fields. + /// + /// 4. It successfully sets a single field without setting its expiration + /// and verifies that the value has been modified, but no expiration is set. + /// 5. It successfully sets multiple fields without setting their expiration + /// and verifies that their values have been modified, but no expiration is set. + /// 6. It successfully sets a single field with an expiration + /// and verifies that the value has been modified and the field is set to expire. + /// 7. It successfully sets all fields with an expiration + /// and verifies that their values have been modified and the fields are set to expire.
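A minimal sketch of the builder pattern the scenarios above rely on, using the HashFieldExpirationOptions, FieldExistenceCheck, and SetExpiry names introduced in this diff; the hash and field names are illustrative:

    use redis::{Commands, FieldExistenceCheck, HashFieldExpirationOptions, SetExpiry};

    fn hsetex_sketch(con: &mut redis::Connection) -> redis::RedisResult<()> {
        // FNX: succeed only if none of the fields exist yet; EX attaches a TTL.
        let opts = HashFieldExpirationOptions::default()
            .set_existence_check(FieldExistenceCheck::FNX)
            .set_expiration(SetExpiry::EX(10));
        let created: bool = con.hset_ex("myhash", &opts, &[("f1", 1), ("f2", 2)])?;
        // FXX: succeed only if all of the fields already exist.
        let opts = opts.set_existence_check(FieldExistenceCheck::FXX);
        let updated: bool = con.hset_ex("myhash", &opts, &[("f1", 3)])?;
        let _ = (created, updated);
        Ok(())
    }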
+ #[test] + fn test_hset_ex() { + let ctx = run_test_if_version_supported!(&REDIS_VERSION_CE_8_0_RC1); + let mut con = ctx.connection(); + + let generated_hash_key = generate_random_testing_hash_key(&mut con); + + let hfe_options = + HashFieldExpirationOptions::default().set_existence_check(FieldExistenceCheck::FNX); + + // Scenario 1 + // Verify that HSETEX with FNX on a hash that does not exist succeeds and creates the hash with the specified fields and values + let fields_set_successfully: bool = con + .hset_ex(&generated_hash_key, &hfe_options, &HASH_FIELDS_AND_VALUES) + .unwrap(); + assert!(fields_set_successfully); + + // Verify that the hash has been created with the expected fields and values + verify_exact_hash_fields_and_values(&mut con, &generated_hash_key, &HASH_FIELDS_AND_VALUES); + + // Scenario 2 + // Executing HSETEX with FNX on a hash that already contains a field should fail + let fields_and_values_for_update = [HASH_FIELDS_AND_VALUES[0], ("NonExistingField", 1)]; + + let field_set_successfully: bool = con + .hset_ex( + &generated_hash_key, + &hfe_options, + &fields_and_values_for_update, + ) + .unwrap(); + assert!(!field_set_successfully); + + // Verify that the hash consists of its original fields + verify_exact_hash_fields_and_values(&mut con, &generated_hash_key, &HASH_FIELDS_AND_VALUES); + + // Scenario 3 + // Executing HSETEX with FXX on a hash that does not have one or more of the fields should fail + let hfe_options = hfe_options.set_existence_check(FieldExistenceCheck::FXX); + + let field_set_successfully: bool = con + .hset_ex( + &generated_hash_key, + &hfe_options, + &fields_and_values_for_update, + ) + .unwrap(); + assert!(!field_set_successfully); + + // Verify that the hash consists of its original fields + verify_exact_hash_fields_and_values(&mut con, &generated_hash_key, &HASH_FIELDS_AND_VALUES); + + // Scenario 4 + // Use the HSETEX command to double the value of the first field without setting its expiration + let initial_fields = HASH_FIELDS_AND_VALUES.map(|(key, _)| key); + + let first_field_with_doubled_value = + [(HASH_FIELDS_AND_VALUES[0].0, HASH_FIELDS_AND_VALUES[0].1 * 2)]; + let field_set_successfully: bool = con + .hset_ex( + &generated_hash_key, + &hfe_options, + &first_field_with_doubled_value, + ) + .unwrap(); + assert!(field_set_successfully); + + // Verify that the field's value has been set to the new value + let hash_fields: HashMap<String, u8> = con.hgetall(&generated_hash_key).unwrap(); + assert_eq!(hash_fields.len(), initial_fields.len()); + assert_eq!( + hash_fields[first_field_with_doubled_value[0].0], + first_field_with_doubled_value[0].1 + ); + + // Verify that the field is not set to expire + assert_eq!( + con.httl(&generated_hash_key, first_field_with_doubled_value[0].0), + Ok([FIELD_EXISTS_WITHOUT_TTL]) + ); + + // Scenario 5 + // Use the HSETEX command to double the original values of all fields without setting their expiration + let fields_with_doubled_values: Vec<(&str, u8)> = HASH_FIELDS_AND_VALUES + .iter() + .map(|(field, value)| (*field, value * 2)) + .collect(); + let fields_set_successfully: bool = con + .hset_ex( + &generated_hash_key, + &hfe_options, + &fields_with_doubled_values, + ) + .unwrap(); + assert!(fields_set_successfully); + + // Verify that the values of the fields have been set to the new values.
+ verify_exact_hash_fields_and_values( + &mut con, + &generated_hash_key, + &fields_with_doubled_values, + ); + + // Verify that the fields are not set to expire + assert_eq!( + con.httl(&generated_hash_key, &initial_fields), + Ok(vec![FIELD_EXISTS_WITHOUT_TTL; initial_fields.len()]) + ); + + // Scenario 6 + // Use the HSETEX command to triple the original value of the first field and set its expiration to 10 seconds + let hfe_options = hfe_options.set_expiration(SetExpiry::EX(10)); + let first_field_with_tripled_value = + [(HASH_FIELDS_AND_VALUES[0].0, HASH_FIELDS_AND_VALUES[0].1 * 3)]; + let field_set_successfully: bool = con + .hset_ex( + &generated_hash_key, + &hfe_options, + &first_field_with_tripled_value, + ) + .unwrap(); + assert!(field_set_successfully); + + // Verify that the field's value has been set to the new value + let hash_fields: HashMap<String, u8> = con.hgetall(&generated_hash_key).unwrap(); + assert_eq!(hash_fields.len(), initial_fields.len()); + assert_eq!( + hash_fields[first_field_with_tripled_value[0].0], + first_field_with_tripled_value[0].1 + ); + + // Verify that the field was set to expire + assert_eq!( + con.httl(&generated_hash_key, first_field_with_tripled_value[0].0), + Ok([10]) + ); + + // Scenario 7 + // Use the HSETEX command to triple the values of all initial fields and set their expiration to 1 second + let hfe_options = hfe_options.set_expiration(SetExpiry::EX(1)); + let fields_with_tripled_values: Vec<(&str, u8)> = HASH_FIELDS_AND_VALUES + .iter() + .map(|(field, value)| (*field, value * 3)) + .collect(); + let fields_set_successfully: bool = con + .hset_ex( + &generated_hash_key, + &hfe_options, + &fields_with_tripled_values, + ) + .unwrap(); + assert!(fields_set_successfully); + + // Verify that the fields' values have been set to the new values + verify_exact_hash_fields_and_values( + &mut con, + &generated_hash_key, + &fields_with_tripled_values, + ); + + // Verify that the fields were set to expire + assert_eq!( + con.httl(&generated_hash_key, &initial_fields), + Ok(vec![1; initial_fields.len()]) + ); + + // Wait for the fields to expire + sleep(Duration::from_millis(1100)); + + // Verify that the fields have expired + let remaining_hash_fields: HashMap<String, u8> = + con.hgetall(&generated_hash_key).unwrap(); + assert_eq!( + remaining_hash_fields.len(), + HASH_FIELDS_AND_VALUES.len() - initial_fields.len() + ); + verify_fields_absence_from_hash(&remaining_hash_fields, &initial_fields); + } + + /// The test validates the various expiration options for hash fields using the HSETEX command. + /// + /// It tests setting expiration using the EX, PX, EXAT, and PXAT options, + /// as well as keeping an existing expiration using the KEEPTTL option.
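The KEEPTTL case deserves a short illustration, since it is the only option that leaves an existing countdown running instead of replacing it; a minimal sketch under the same assumptions as the previous snippet (illustrative names, API calls mirroring the ones used in this diff):

    use redis::{Commands, FieldExistenceCheck, HashFieldExpirationOptions, SetExpiry};

    fn keepttl_sketch(con: &mut redis::Connection) -> redis::RedisResult<()> {
        // The first write attaches a 60-second TTL to the field.
        let opts = HashFieldExpirationOptions::default()
            .set_existence_check(FieldExistenceCheck::FXX)
            .set_expiration(SetExpiry::EX(60));
        let _: bool = con.hset_ex("myhash", &opts, &[("f1", 1)])?;
        // A later write with KEEPTTL replaces the value but keeps that TTL running.
        let opts = opts.set_expiration(SetExpiry::KEEPTTL);
        let _: bool = con.hset_ex("myhash", &opts, &[("f1", 2)])?;
        Ok(())
    }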
+ #[test] + fn test_hsetex_field_expiration_options() { + let ctx = run_test_if_version_supported!(&REDIS_VERSION_CE_8_0_RC1); + let mut con = ctx.connection(); + // Create a hash with multiple fields and values that will be used for testing + assert_eq!(con.hset_multiple(HASH_KEY, &HASH_FIELDS_AND_VALUES), Ok(())); + + // Verify that initially all fields are present in the hash + verify_exact_hash_fields_and_values(&mut con, HASH_KEY, &HASH_FIELDS_AND_VALUES); + + // Set the fields to expire in 1 second using different expiration options + let hfe_options = HashFieldExpirationOptions::default() + .set_existence_check(FieldExistenceCheck::FXX) + .set_expiration(SetExpiry::EX(1)); + + let expiration_set_successfully: bool = con + .hset_ex(HASH_KEY, &hfe_options, &[HASH_FIELDS_AND_VALUES[0]]) + .unwrap(); + assert!(expiration_set_successfully); + + let hfe_options = hfe_options.set_expiration(SetExpiry::PX(1000)); + let expiration_set_successfully: bool = con + .hset_ex(HASH_KEY, &hfe_options, &[HASH_FIELDS_AND_VALUES[1]]) + .unwrap(); + assert!(expiration_set_successfully); + + let current_timestamp = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); + + let hfe_options = + hfe_options.set_expiration(SetExpiry::EXAT(current_timestamp.as_secs() + 1)); + let expiration_set_successfully: bool = con + .hset_ex(HASH_KEY, &hfe_options, &[HASH_FIELDS_AND_VALUES[2]]) + .unwrap(); + assert!(expiration_set_successfully); + + let hfe_options = hfe_options + .set_expiration(SetExpiry::PXAT(current_timestamp.as_millis() as u64 + 1000)); + let expiration_set_successfully: bool = con + .hset_ex(HASH_KEY, &hfe_options, &[HASH_FIELDS_AND_VALUES[3]]) + .unwrap(); + assert!(expiration_set_successfully); + + let hfe_options = hfe_options.set_expiration(SetExpiry::EX(1)); + let expiration_set_successfully: bool = con + .hset_ex(HASH_KEY, &hfe_options, &[HASH_FIELDS_AND_VALUES[4]]) + .unwrap(); + assert!(expiration_set_successfully); + + // Using KEEPTTL will preserve the 1-second TTL set above + let hfe_options = hfe_options.set_expiration(SetExpiry::KEEPTTL); + let expiration_set_successfully: bool = con + .hset_ex(HASH_KEY, &hfe_options, &[HASH_FIELDS_AND_VALUES[4]]) + .unwrap(); + assert!(expiration_set_successfully); + + // Wait for the fields to expire and verify it + sleep(Duration::from_millis(1100)); + + // Verify that all fields have expired and the hash no longer exists + assert_eq!(con.exists(HASH_KEY), Ok(false)); + } + + #[test] + fn test_hsetex_can_update_the_expiration_of_a_field_that_has_already_been_set_to_expire() { + let ctx = run_test_if_version_supported!(&REDIS_VERSION_CE_8_0_RC1); + let mut con = ctx.connection(); + // Create a hash with multiple fields and values that will be used for testing + assert_eq!(con.hset_multiple(HASH_KEY, &HASH_FIELDS_AND_VALUES), Ok(())); + + let hfe_options = HashFieldExpirationOptions::default() + .set_existence_check(FieldExistenceCheck::FXX) + .set_expiration(SetExpiry::EX(1)); + + // Use the HSETEX command to set the first field to expire in 1 second + let field_set_successfully: bool = con + .hset_ex(HASH_KEY, &hfe_options, &[HASH_FIELDS_AND_VALUES[0]]) + .unwrap(); + assert!(field_set_successfully); + + // Use the HSETEX command again to set the timeout to 2 seconds + let hfe_options = hfe_options.set_expiration(SetExpiry::EX(2)); + let field_set_successfully: bool = con + .hset_ex(HASH_KEY, &hfe_options, &[HASH_FIELDS_AND_VALUES[0]]) + .unwrap(); + assert!(field_set_successfully); + // Verify that all of the fields still have their initial values +
verify_exact_hash_fields_and_values(&mut con, HASH_KEY, &HASH_FIELDS_AND_VALUES); + + // Verify that the field was set to expire + assert_eq!(con.httl(HASH_KEY, HASH_FIELDS_AND_VALUES[0].0), Ok([2])); + + // Wait for the field to expire + sleep(Duration::from_millis(2100)); + + // Verify that the field has expired + let remaining_hash_fields: HashMap<String, u8> = con.hgetall(HASH_KEY).unwrap(); + assert_eq!( + remaining_hash_fields.len(), + HASH_FIELDS_AND_VALUES.len() - 1 + ); + verify_fields_absence_from_hash(&remaining_hash_fields, &[HASH_FIELDS_AND_VALUES[0].0]); + + // Remove the hash + assert_eq!(con.del(HASH_KEY), Ok(1)); + } + // Requires redis-server >= 4.0.0. // Not supported with the current appveyor/windows binary deployed. #[cfg(not(target_os = "windows"))] @@ -769,6 +1396,23 @@ mod basic { assert_eq!(k2, 45); } + #[test] + fn test_pipeline_len() { + let mut pl = redis::pipe(); + + pl.cmd("PING").cmd("SET").arg(1); + assert_eq!(pl.len(), 2); + } + + #[test] + fn test_pipeline_is_empty() { + let mut pl = redis::pipe(); + + assert!(pl.is_empty()); + pl.cmd("PING").cmd("SET").arg(1); + assert!(!pl.is_empty()); + } + #[test] fn test_real_transaction() { let ctx = TestContext::new(); diff --git a/redis/tests/test_cache.rs b/redis/tests/test_cache.rs index fa4b9f9f6..e892ed0ac 100644 --- a/redis/tests/test_cache.rs +++ b/redis/tests/test_cache.rs @@ -2,7 +2,6 @@ use crate::support::*; use futures_time::task::sleep; -use redis::aio::MultiplexedConnection; use redis::CommandCacheConfig; use redis::{caching::CacheConfig, AsyncCommands, ProtocolVersion, RedisError}; use redis_test::server::Module; @@ -13,9 +12,27 @@ use std::time::Duration; mod support; +macro_rules! assert_hit { + ($con:expr, $val:expr) => { + assert_eq!($con.get_cache_statistics().unwrap().hit, $val); + }; +} + +macro_rules! assert_miss { + ($con:expr, $val:expr) => { + assert_eq!($con.get_cache_statistics().unwrap().miss, $val); + }; +} + +macro_rules! assert_invalidate { + ($con:expr, $val:expr) => { + assert_eq!($con.get_cache_statistics().unwrap().invalidate, $val); + }; +} + // Basic testing should work with both CacheMode::All and CacheMode::OptIn if commands has called cache() #[rstest] -#[case::tokio(RuntimeType::Tokio)] +#[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] fn test_cache_basic(#[case] runtime: RuntimeType, #[values(true, false)] test_with_optin: bool) { let ctx = TestContext::new(); @@ -36,8 +53,8 @@ .await .unwrap(); assert_eq!(val, None); - assert_hit(&con, 0); - assert_miss(&con, 1); + assert_hit!(&con, 0); + assert_miss!(&con, 1); let val: Option<String> = get_cmd("GET", test_with_optin) .arg("key_1") @@ -46,8 +63,8 @@ fn test_cache_basic(#[case] runtime: RuntimeType, #[values(true, false)] test_wi .await .unwrap(); assert_eq!(val, None); // key_1's value should be returned from cache even if it doesn't exist in server yet. - assert_hit(&con, 1); - assert_miss(&con, 1); + assert_hit!(&con, 1); + assert_miss!(&con, 1); let _: () = get_cmd("SET", test_with_optin) .arg("key_1") @@ -56,9 +73,9 @@ fn test_cache_basic(#[case] runtime: RuntimeType, #[values(true, false)] test_wi .await .unwrap(); sleep(Duration::from_millis(50).into()).await; // Give time for push message to be received after invalidating key_1.
- assert_hit(&con, 1); - assert_miss(&con, 1); - assert_invalidate(&con, 1); + assert_hit!(&con, 1); + assert_miss!(&con, 1); + assert_invalidate!(&con, 1); let val: String = get_cmd("GET", test_with_optin) .arg("key_1") @@ -67,9 +84,9 @@ fn test_cache_basic(#[case] runtime: RuntimeType, #[values(true, false)] test_wi .unwrap(); assert_eq!(val, "1"); // After invalidating key_1, now it misses the key from cache - assert_hit(&con, 1); - assert_miss(&con, 2); - assert_invalidate(&con, 1); + assert_hit!(&con, 1); + assert_miss!(&con, 2); + assert_invalidate!(&con, 1); Ok::<_, RedisError>(()) }, runtime, @@ -78,7 +95,7 @@ } #[rstest] -#[case::tokio(RuntimeType::Tokio)] +#[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] fn test_cache_mget(#[case] runtime: RuntimeType) { let ctx = TestContext::new(); @@ -107,8 +124,8 @@ .arg("key_2") .query_async(&mut con) .await?; - assert_hit(&con, 0); - assert_miss(&con, 2); + assert_hit!(&con, 0); + assert_miss!(&con, 2); assert_eq!(res1, vec![Some("41".to_string()), None]); let res2: Vec<Option<String>> = redis::cmd("MGET") .arg("key_1") .arg("key_3") .arg("key_2") .query_async(&mut con) .await?; - assert_hit(&con, 2); - assert_miss(&con, 3); + assert_hit!(&con, 2); + assert_miss!(&con, 3); assert_eq!( res2, vec![Some("41".to_string()), Some("43".to_string()), None] ); let _: Option<String> = redis::cmd("GET").arg("key_1").query_async(&mut con).await?; - assert_hit(&con, 3); - assert_miss(&con, 3); + assert_hit!(&con, 3); + assert_miss!(&con, 3); Ok::<_, RedisError>(()) }, runtime, @@ -136,7 +153,7 @@ } #[rstest] #[cfg(feature = "json")] -#[case::tokio(RuntimeType::Tokio)] +#[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] fn test_module_cache_json_get_mget(#[case] runtime: RuntimeType) { let ctx = TestContext::with_modules(&[Module::Json], false); @@ -174,8 +191,8 @@ assert_eq!(res1.len(), 2); assert_eq!(res1, vec![Some(value_1.clone()), None]); - assert_hit(&con, 0); - assert_miss(&con, 2); + assert_hit!(&con, 0); + assert_miss!(&con, 2); let res2: Vec<Option<String>> = get_cmd("JSON.MGET", true) .arg("key_1") @@ -187,8 +204,8 @@ assert_eq!(res2.len(), 3); assert_eq!(res2, vec![Some(value_1.clone()), Some(value_3), None]); - assert_hit(&con, 2); - assert_miss(&con, 3); + assert_hit!(&con, 2); + assert_miss!(&con, 3); let res3: Option<String> = get_cmd("JSON.GET", true) .arg("key_1") @@ -197,8 +214,8 @@ .await?; assert_eq!(res3, Some(value_1.clone())); - assert_hit(&con, 3); - assert_miss(&con, 3); + assert_hit!(&con, 3); + assert_miss!(&con, 3); Ok::<_, RedisError>(()) }, @@ -209,7 +226,7 @@ } #[rstest] #[cfg(feature = "json")] -#[case::tokio(RuntimeType::Tokio)] +#[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] fn test_module_cache_json_get_mget_different_paths(#[case] runtime: RuntimeType) { let ctx =
TestContext::with_modules(&[Module::Json], false); @@ -247,8 +264,8 @@ fn test_module_cache_json_get_mget_different_paths(#[case] runtime: RuntimeType) assert_eq!(res1.len(), 2); assert_eq!(res1, vec![Some(41), None]); - assert_hit(&con, 0); - assert_miss(&con, 2); + assert_hit!(&con, 0); + assert_miss!(&con, 2); let res2: Vec<Option<i32>> = get_cmd("JSON.MGET", true) .arg(&["key_1", "key_2"]) @@ -258,8 +275,8 @@ fn test_module_cache_json_get_mget_different_paths(#[case] runtime: RuntimeType) assert_eq!(res2.len(), 2); assert_eq!(res2, vec![Some(41), None]); - assert_hit(&con, 2); - assert_miss(&con, 2); + assert_hit!(&con, 2); + assert_miss!(&con, 2); let res3: Vec<Option<i32>> = get_cmd("JSON.MGET", true) .arg(&["key_1", "key_3"]) @@ -269,8 +286,8 @@ fn test_module_cache_json_get_mget_different_paths(#[case] runtime: RuntimeType) assert_eq!(res3.len(), 2); assert_eq!(res3, vec![Some(1), Some(3)]); - assert_hit(&con, 2); - assert_miss(&con, 4); + assert_hit!(&con, 2); + assert_miss!(&con, 4); let res4: Vec<Option<i32>> = get_cmd("JSON.MGET", true) .arg(&["key_2", "key_3"]) @@ -280,8 +297,8 @@ fn test_module_cache_json_get_mget_different_paths(#[case] runtime: RuntimeType) assert_eq!(res4.len(), 2); assert_eq!(res4, vec![None, Some(3)]); - assert_hit(&con, 3); - assert_miss(&con, 5); + assert_hit!(&con, 3); + assert_miss!(&con, 5); let res5: Option<i32> = get_cmd("JSON.GET", true) .arg("key_1") @@ -290,8 +307,8 @@ fn test_module_cache_json_get_mget_different_paths(#[case] runtime: RuntimeType) .await?; assert_eq!(res5, Some(41)); - assert_hit(&con, 4); - assert_miss(&con, 5); + assert_hit!(&con, 4); + assert_miss!(&con, 5); let res6: Option<i32> = get_cmd("JSON.GET", true) .arg("key_1") @@ -300,8 +317,8 @@ fn test_module_cache_json_get_mget_different_paths(#[case] runtime: RuntimeType) .await?; assert_eq!(res6, Some(41)); - assert_hit(&con, 5); - assert_miss(&con, 5); + assert_hit!(&con, 5); + assert_miss!(&con, 5); let res7: Option<i32> = get_cmd("JSON.GET", true) .arg("key_1") @@ -310,8 +327,8 @@ fn test_module_cache_json_get_mget_different_paths(#[case] runtime: RuntimeType) .await?; assert_eq!(res7, Some(1)); - assert_hit(&con, 6); - assert_miss(&con, 5); + assert_hit!(&con, 6); + assert_miss!(&con, 5); Ok::<_, RedisError>(()) }, @@ -321,7 +338,7 @@ } #[rstest] -#[case::tokio(RuntimeType::Tokio)] +#[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] fn test_cache_is_not_target_type_dependent(#[case] runtime: RuntimeType) { let ctx = TestContext::new(); @@ -339,8 +356,8 @@ assert_eq!(x, "77"); let x: u8 = con.get("KEY").await?; assert_eq!(x, 77); - assert_hit(&con, 2); - assert_miss(&con, 1); + assert_hit!(&con, 2); + assert_miss!(&con, 1); Ok::<_, RedisError>(()) }, runtime, @@ -349,7 +366,7 @@ } #[rstest] -#[case::tokio(RuntimeType::Tokio)] +#[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] fn test_cache_with_pipeline(#[case] runtime: RuntimeType, #[values(true, false)] atomic: bool) { let ctx = TestContext::new(); @@ -376,8 +393,8 @@ assert_eq!(mget_k1_k2, (41, 42)); // There are 2 miss for key_1, key_2 used with MGET - assert_hit(&con,
0); - assert_miss(&con, 2); + assert_hit!(&con, 0); + assert_miss!(&con, 2); let (k1, mget_k1_k2, k_unknown): (i32, (i32, i32), Option<i32>) = get_pipe(atomic) .cmd("GET") @@ -392,8 +409,8 @@ fn test_cache_with_pipeline(#[case] runtime: RuntimeType, #[values(true, false)] assert_eq!(k1, 41); assert_eq!(mget_k1_k2, (41, 42)); assert_eq!(k_unknown, Option::None); - assert_hit(&con, 3); - assert_miss(&con, 3); + assert_hit!(&con, 3); + assert_miss!(&con, 3); Ok::<_, RedisError>(()) }, @@ -403,7 +420,7 @@ } #[rstest] -#[case::tokio(RuntimeType::Tokio)] +#[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] fn test_cache_basic_partial_opt_in(#[case] runtime: RuntimeType) { // In OptIn mode cache must not be utilized without explicit per command configuration. @@ -422,8 +439,8 @@ .unwrap(); assert_eq!(val, None); // GET is not marked with cache(), there should be no MISS/HIT - assert_hit(&con, 0); - assert_miss(&con, 0); + assert_hit!(&con, 0); + assert_miss!(&con, 0); let _: () = redis::cmd("SET") .arg("key_1") @@ -432,9 +449,9 @@ .await .unwrap(); // There should be no invalidation since cache is not used. - assert_hit(&con, 0); - assert_miss(&con, 0); - assert_invalidate(&con, 0); + assert_hit!(&con, 0); + assert_miss!(&con, 0); + assert_invalidate!(&con, 0); let val: String = redis::cmd("GET") .arg("key_1") @@ -443,8 +460,8 @@ .await .unwrap(); assert_eq!(val, "1"); - assert_hit(&con, 0); - assert_miss(&con, 1); + assert_hit!(&con, 0); + assert_miss!(&con, 1); let val: String = redis::cmd("GET") .arg("key_1") @@ -453,8 +470,8 @@ .unwrap(); assert_eq!(val, "1"); // Since cache is not used, hit should still be 0 - assert_hit(&con, 0); - assert_miss(&con, 1); + assert_hit!(&con, 0); + assert_miss!(&con, 1); let val: String = redis::cmd("GET") .arg("key_1") @@ -463,9 +480,9 @@ .await .unwrap(); assert_eq!(val, "1"); - assert_hit(&con, 1); - assert_miss(&con, 1); - assert_invalidate(&con, 0); + assert_hit!(&con, 1); + assert_miss!(&con, 1); + assert_invalidate!(&con, 0); Ok::<_, RedisError>(()) }, runtime, @@ -474,7 +491,7 @@ } #[rstest] -#[case::tokio(RuntimeType::Tokio)] +#[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] fn test_cache_pipeline_partial_opt_in( #[case] runtime: RuntimeType, @@ -506,8 +523,8 @@ assert_eq!(mget_k1_k2, (42, 43)); // Since CacheMode::OptIn is enabled, so there should be no miss or hit - assert_hit(&con, 0); - assert_miss(&con, 0); + assert_hit!(&con, 0); + assert_miss!(&con, 0); for _ in 0..2 { let (mget_k1_k2, k1, k_unknown): ((i32, i32), i32, Option<i32>) = get_pipe(atomic) @@ -526,8 +543,8 @@ assert_eq!(k_unknown, Option::None); } // Only MGET should be use cache path, since pipeline used twice there should be one miss and one hit.
- assert_hit(&con, 2); - assert_miss(&con, 2); + assert_hit!(&con, 2); + assert_miss!(&con, 2); Ok::<_, RedisError>(()) }, @@ -537,7 +554,7 @@ } #[rstest] -#[case::tokio(RuntimeType::Tokio)] +#[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] fn test_cache_different_commands( #[case] runtime: RuntimeType, @@ -570,8 +587,8 @@ .await .unwrap(); assert_eq!(val, 100); - assert_hit(&con, 0); - assert_miss(&con, 1); + assert_hit!(&con, 0); + assert_miss!(&con, 1); let val: Option<i32> = get_cmd("HGET", test_with_opt_in) .arg("user") @@ -580,8 +597,8 @@ .await .unwrap(); assert_eq!(val, None); - assert_hit(&con, 0); - assert_miss(&con, 2); + assert_hit!(&con, 0); + assert_miss!(&con, 2); let val: HashMap<String, i32> = get_cmd("HGETALL", test_with_opt_in) .arg("user") @@ -589,8 +606,8 @@ .await .unwrap(); assert_eq!(val.get("health"), Some(100).as_ref()); - assert_hit(&con, 0); - assert_miss(&con, 3); + assert_hit!(&con, 0); + assert_miss!(&con, 3); let val: HashMap<String, i32> = get_cmd("HGETALL", test_with_opt_in) .arg("user") @@ -598,8 +615,8 @@ .await .unwrap(); assert_eq!(val.get("health"), Some(100).as_ref()); - assert_hit(&con, 1); - assert_miss(&con, 3); + assert_hit!(&con, 1); + assert_miss!(&con, 3); Ok::<_, RedisError>(()) }, runtime, ) .unwrap(); } @@ -607,6 +624,77 @@ +#[rstest] +#[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] +#[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] +#[cfg(feature = "connection-manager")] +fn test_connection_manager_maintains_statistics_after_crashes( + #[case] runtime: RuntimeType, + #[values(true, false)] test_with_optin: bool, +) { + let ctx = TestContext::new(); + if ctx.protocol == ProtocolVersion::RESP2 { + return; + } + block_on_all( + async move { + let cache_config = if test_with_optin { + CacheConfig::new().set_mode(redis::caching::CacheMode::OptIn) + } else { + CacheConfig::default() + }; + + let config = redis::aio::ConnectionManagerConfig::new() + .set_max_delay(1) + .set_cache_config(cache_config); + + let mut get = get_cmd("GET", test_with_optin); + get.arg("key_1"); + + let mut manager = ctx + .client + .get_connection_manager_with_config(config) + .await + .unwrap(); + + let val: Option<String> = get.query_async(&mut manager).await.unwrap(); + assert_eq!(val, None); + assert_hit!(&manager, 0); + assert_miss!(&manager, 1); + + let val: Option<String> = get.query_async(&mut manager).await.unwrap(); + assert_eq!(val, None); + assert_hit!(&manager, 1); + assert_miss!(&manager, 1); + + let addr = ctx.server.client_addr().clone(); + drop(ctx); + + let result: Result<redis::Value, redis::RedisError> = + manager.send_packed_command(&redis::cmd("PING")).await; + assert!(result.unwrap_err().is_unrecoverable_error()); + + let _server = + redis_test::server::RedisServer::new_with_addr_and_modules(addr, &[], false); + + loop { + if manager.send_packed_command(&get).await.is_ok() { + break; + } + } + + assert_eq!(val, None); + // The key should've been invalidated after the disconnect + assert_hit!(&manager, 1); + assert_miss!(&manager, 2); + + Ok(()) + }, + runtime, + ) + .unwrap(); +} + // Support function for testing pipelines fn get_pipe(atomic: bool) -> redis::Pipeline { if atomic { @@ -626,15 +714,3 @@ fn get_cmd(name: &str, enable_opt_in: bool) -> redis::Cmd { ... } cmd } - -fn
assert_hit(con: &MultiplexedConnection, val: usize) { - assert_eq!(con.get_cache_statistics().unwrap().hit, val); -} - -fn assert_miss(con: &MultiplexedConnection, val: usize) { - assert_eq!(con.get_cache_statistics().unwrap().miss, val); -} - -fn assert_invalidate(con: &MultiplexedConnection, val: usize) { - assert_eq!(con.get_cache_statistics().unwrap().invalidate, val); -} diff --git a/redis/tests/test_cluster.rs b/redis/tests/test_cluster.rs index 74e5b6ead..caf716f7b 100644 --- a/redis/tests/test_cluster.rs +++ b/redis/tests/test_cluster.rs @@ -10,7 +10,7 @@ mod cluster { use crate::support::*; use redis::{ - cluster::{cluster_pipe, ClusterClient}, + cluster::{cluster_pipe, ClusterClient, ClusterConnection}, cluster_routing::{MultipleNodeRoutingInfo, RoutingInfo, SingleNodeRoutingInfo}, cmd, parse_redis_value, Commands, ConnectionLike, ErrorKind, ProtocolVersion, RedisError, Value, @@ -20,11 +20,7 @@ mod cluster { server::use_protocol, }; - #[test] - fn test_cluster_basics() { - let cluster = TestClusterContext::new(); - let mut con = cluster.connection(); - + fn smoke_test_connection(mut con: ClusterConnection) { redis::cmd("SET") .arg("{x}key1") .arg(b"foo") @@ -43,6 +39,52 @@ mod cluster { ); } + #[test] + fn test_cluster_basics() { + let cluster = TestClusterContext::new(); + smoke_test_connection(cluster.connection()); + } + + #[cfg(feature = "tls-rustls")] + #[test] + fn test_default_reject_invalid_hostnames() { + use redis_test::cluster::ClusterType; + + if ClusterType::get_intended() != ClusterType::TcpTls { + // Only TLS causes invalid certificates to be rejected as desired. + return; + } + + let cluster = TestClusterContext::new_with_config(RedisClusterConfiguration { + tls_insecure: false, + certs_with_ip_alts: false, + ..Default::default() + }); + assert!(cluster.client.get_connection().is_err()); + } + + #[cfg(feature = "tls-rustls-insecure")] + #[test] + fn test_danger_accept_invalid_hostnames() { + use redis_test::cluster::ClusterType; + + if ClusterType::get_intended() != ClusterType::TcpTls { + // No point testing this TLS-specific mode in non-TLS configurations. 
+ return; + } + + let cluster = TestClusterContext::new_with_config_and_builder( + RedisClusterConfiguration { + tls_insecure: false, + certs_with_ip_alts: false, + ..Default::default() + }, + |builder| builder.danger_accept_invalid_hostnames(true), + ); + + smoke_test_connection(cluster.connection()); + } + #[test] fn test_cluster_with_username_and_password() { let cluster = TestClusterContext::new_with_cluster_client_builder(|builder| { @@ -52,24 +94,7 @@ mod cluster { }); cluster.disable_default_user(); - let mut con = cluster.connection(); - - redis::cmd("SET") - .arg("{x}key1") - .arg(b"foo") - .exec(&mut con) - .unwrap(); - redis::cmd("SET") - .arg(&["{x}key2", "bar"]) - .exec(&mut con) - .unwrap(); - - assert_eq!( - redis::cmd("MGET") - .arg(&["{x}key1", "{x}key2"]) - .query(&mut con), - Ok(("foo".to_string(), b"bar".to_vec())) - ); + smoke_test_connection(cluster.connection()); } #[test] @@ -1023,8 +1048,9 @@ mod cluster { #[test] fn test_cluster_reconnect_after_complete_server_disconnect() { - let cluster = - TestClusterContext::new_with_cluster_client_builder(|builder| builder.retries(3)); + let cluster = TestClusterContext::new_insecure_with_cluster_client_builder(|builder| { + builder.retries(3) + }); let ports: Vec<_> = cluster .nodes @@ -1064,8 +1090,9 @@ mod cluster { #[test] fn test_cluster_reconnect_after_complete_server_disconnect_route_to_many() { - let cluster = - TestClusterContext::new_with_cluster_client_builder(|builder| builder.retries(3)); + let cluster = TestClusterContext::new_insecure_with_cluster_client_builder(|builder| { + builder.retries(3) + }); let ports: Vec<_> = cluster .nodes diff --git a/redis/tests/test_cluster_async.rs b/redis/tests/test_cluster_async.rs index c66c111da..abcc497cf 100644 --- a/redis/tests/test_cluster_async.rs +++ b/redis/tests/test_cluster_async.rs @@ -38,26 +38,33 @@ mod cluster_async { )) } + async fn smoke_test_connection(mut connection: impl redis::aio::ConnectionLike) { + cmd("SET") + .arg("test") + .arg("test_data") + .exec_async(&mut connection) + .await + .expect("SET command should succeed"); + let res: String = cmd("GET") + .arg("test") + .clone() + .query_async(&mut connection) + .await + .expect("GET command should succeed"); + assert_eq!(res, "test_data"); + } + #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_async_cluster_basic_cmd(#[case] runtime: RuntimeType) { let cluster = TestClusterContext::new(); block_on_all( async move { - let mut connection = cluster.async_connection().await; - cmd("SET") - .arg("test") - .arg("test_data") - .exec_async(&mut connection) - .await?; - let res: String = cmd("GET") - .arg("test") - .clone() - .query_async(&mut connection) - .await?; - assert_eq!(res, "test_data"); + let connection = cluster.async_connection().await; + smoke_test_connection(connection).await; Ok::<_, RedisError>(()) }, runtime, @@ -66,8 +73,9 @@ mod cluster_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_async_cluster_basic_eval(#[case] runtime: RuntimeType) { let cluster = TestClusterContext::new(); @@ -92,8 +100,9 @@ mod cluster_async { } #[rstest] - 
#[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_async_cluster_basic_script(#[case] runtime: RuntimeType) { let cluster = TestClusterContext::new(); @@ -116,8 +125,9 @@ mod cluster_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_async_cluster_route_flush_to_specific_node(#[case] runtime: RuntimeType) { let cluster = TestClusterContext::new(); @@ -156,8 +166,9 @@ mod cluster_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_async_cluster_route_flush_to_node_by_address(#[case] runtime: RuntimeType) { let cluster = TestClusterContext::new(); @@ -202,8 +213,9 @@ mod cluster_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_async_cluster_route_info_to_nodes(#[case] runtime: RuntimeType) { let cluster = TestClusterContext::new_with_config(RedisClusterConfiguration { num_nodes: 12, @@ -290,8 +302,9 @@ mod cluster_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_cluster_resp3(#[case] runtime: RuntimeType) { if use_protocol() == ProtocolVersion::RESP2 { return; @@ -328,8 +341,9 @@ mod cluster_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_async_cluster_basic_pipe(#[case] runtime: RuntimeType) { let cluster = TestClusterContext::new(); @@ -352,8 +366,9 @@ mod cluster_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_async_cluster_multi_shard_commands(#[case] runtime: RuntimeType) { let cluster = TestClusterContext::new(); @@ -374,9 +389,70 @@ mod cluster_async { .unwrap() } + #[cfg(feature = "tls-rustls")] + #[rstest] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] + #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + fn test_async_cluster_default_reject_invalid_hostnames(#[case] runtime: RuntimeType) { + use redis_test::cluster::ClusterType; + + if ClusterType::get_intended() != ClusterType::TcpTls { + // Only TLS causes invalid certificates to be rejected as desired. 
+ return; + } + + let cluster = TestClusterContext::new_with_config(RedisClusterConfiguration { + tls_insecure: false, + certs_with_ip_alts: false, + ..Default::default() + }); + + block_on_all( + async move { + assert!(cluster.client.get_async_connection().await.is_err()); + Ok(()) + }, + runtime, + ) + .unwrap(); + } + + #[cfg(feature = "tls-rustls-insecure")] + #[rstest] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] + #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + fn test_async_cluster_danger_accept_invalid_hostnames(#[case] runtime: RuntimeType) { + use redis_test::cluster::ClusterType; + + if ClusterType::get_intended() != ClusterType::TcpTls { + // No point testing this TLS-specific mode in non-TLS configurations. + return; + } + + let cluster = TestClusterContext::new_with_config_and_builder( + RedisClusterConfiguration { + tls_insecure: false, + certs_with_ip_alts: false, + ..Default::default() + }, + |builder| builder.danger_accept_invalid_hostnames(true), + ); + + block_on_all( + async move { + let connection = cluster.async_connection().await; + smoke_test_connection(connection).await; + Ok(()) + }, + runtime, + ) + .unwrap(); + } + #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_async_cluster_basic_failover(#[case] runtime: RuntimeType) { block_on_all( async move { @@ -442,15 +518,10 @@ mod cluster_async { let role: String = info.get("role").expect("cluster role"); if role == "master" { - async { - Ok(redis::Cmd::new() - .arg("FLUSHALL") - .exec_async(&mut conn) - .await?) - } - .timeout(futures_time::time::Duration::from_secs(3)) - .await - .unwrap_or_else(|err| Err(anyhow::Error::from(err)))?; + async { Ok(conn.flushall::<()>().await?) 
} + .timeout(futures_time::time::Duration::from_secs(3)) + .await + .unwrap_or_else(|err| Err(anyhow::Error::from(err)))?; } node_conns.push(conn); @@ -525,18 +596,15 @@ mod cluster_async { } impl Connect for ErrorConnection { - fn connect<'a, T>( + fn connect_with_config<'a, T>( info: T, - response_timeout: Duration, - connection_timeout: Duration, + config: redis::AsyncConnectionConfig, ) -> RedisFuture<'a, Self> where T: IntoConnectionInfo + Send + 'a, { Box::pin(async move { - let inner = - MultiplexedConnection::connect(info, response_timeout, connection_timeout) - .await?; + let inner = MultiplexedConnection::connect_with_config(info, config).await?; Ok(ErrorConnection { inner }) }) } @@ -1836,10 +1904,11 @@ mod cluster_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_async_cluster_with_username_and_password(#[case] runtime: RuntimeType) { - let cluster = TestClusterContext::new_with_cluster_client_builder(|builder| { + let cluster = TestClusterContext::new_insecure_with_cluster_client_builder(|builder| { builder .username(RedisCluster::username().to_string()) .password(RedisCluster::password().to_string()) @@ -1982,8 +2051,9 @@ mod cluster_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_async_cluster_handle_complete_server_disconnect_without_panicking( #[case] runtime: RuntimeType, ) { @@ -2013,11 +2083,13 @@ mod cluster_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_async_cluster_reconnect_after_complete_server_disconnect(#[case] runtime: RuntimeType) { - let cluster = - TestClusterContext::new_with_cluster_client_builder(|builder| builder.retries(2)); + let cluster = TestClusterContext::new_insecure_with_cluster_client_builder(|builder| { + builder.retries(2) + }); block_on_all( async move { @@ -2065,13 +2137,15 @@ mod cluster_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_async_cluster_reconnect_after_complete_server_disconnect_route_to_many( #[case] runtime: RuntimeType, ) { - let cluster = - TestClusterContext::new_with_cluster_client_builder(|builder| builder.retries(3)); + let cluster = TestClusterContext::new_insecure_with_cluster_client_builder(|builder| { + builder.retries(3) + }); block_on_all( async move { @@ -2187,8 +2261,9 @@ mod cluster_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_kill_connection_on_drop_even_when_blocking(#[case] runtime: RuntimeType) { let ctx = 
TestClusterContext::new_with_cluster_client_builder(|builder| builder.retries(3)); @@ -2407,7 +2482,7 @@ mod cluster_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] fn pub_sub_subscription(#[case] runtime: RuntimeType) { let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); @@ -2435,7 +2510,7 @@ mod cluster_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] fn pub_sub_subscription_with_config(#[case] runtime: RuntimeType) { let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); @@ -2463,7 +2538,7 @@ mod cluster_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] fn pub_sub_shardnumsub(#[case] runtime: RuntimeType) { let ctx = TestClusterContext::new_with_cluster_client_builder(|builder| { @@ -2495,7 +2570,7 @@ mod cluster_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] fn pub_sub_unsubscription(#[case] runtime: RuntimeType) { let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); @@ -2601,7 +2676,7 @@ mod cluster_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] fn connection_is_still_usable_if_pubsub_receiver_is_dropped(#[case] runtime: RuntimeType) { let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); @@ -2636,7 +2711,7 @@ mod cluster_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] fn multiple_subscribes_and_unsubscribes_work(#[case] runtime: RuntimeType) { // In this test we subscribe on all subscription variations to 3 channels in a single call, then unsubscribe from 2 channels. @@ -2780,14 +2855,14 @@ mod cluster_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] fn pub_sub_reconnect_after_disconnect(#[case] runtime: RuntimeType) { // in this test we will subscribe to channels, then restart the server, and check that the connection // doesn't send disconnect message, but instead resubscribes automatically. 
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); - let ctx = TestClusterContext::new_with_cluster_client_builder(|builder| { + let ctx = TestClusterContext::new_insecure_with_cluster_client_builder(|builder| { builder .use_protocol(ProtocolVersion::RESP3) .push_sender(tx.clone()) @@ -2898,8 +2973,9 @@ mod cluster_async { use super::*; #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_async_cluster_basic_cmd_with_mtls(#[case] runtime: RuntimeType) { let cluster = TestClusterContext::new_with_mtls(); block_on_all( @@ -2925,8 +3001,9 @@ mod cluster_async { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_async_cluster_should_not_connect_without_mtls_enabled( #[case] runtime: RuntimeType, ) { diff --git a/redis/tests/test_sentinel.rs b/redis/tests/test_sentinel.rs index 751130dc5..340edcdbc 100644 --- a/redis/tests/test_sentinel.rs +++ b/redis/tests/test_sentinel.rs @@ -103,25 +103,74 @@ fn assert_connect_to_known_replicas( } } +#[test] +fn test_sentinel_role_no_permission() { + let number_of_replicas = 3; + let master_name = "master1"; + let mut cluster = TestSentinelContext::new(2, number_of_replicas, 3); + let node_conn_info = cluster.sentinel_node_connection_info(); + let sentinel = cluster.sentinel_mut(); + + let master_client = sentinel + .master_for(master_name, Some(&node_conn_info)) + .unwrap(); + let mut master_con = master_client.get_connection().unwrap(); + + let user: String = redis::cmd("ACL") + .arg("whoami") + .query(&mut master_con) + .unwrap(); + //Remove ROLE permission for the given user on master + let _: () = redis::cmd("ACL") + .arg("SETUSER") + .arg(&user) + .arg("-role") + .query(&mut master_con) + .unwrap(); + + //Remove ROLE permission for the given user on replicas + for _ in 0..number_of_replicas { + let replica_client = sentinel + .replica_rotate_for(master_name, Some(&node_conn_info)) + .unwrap(); + let mut replica_con = replica_client.get_connection().unwrap(); + let _: () = redis::cmd("ACL") + .arg("SETUSER") + .arg(&user) + .arg("-role") + .query(&mut replica_con) + .unwrap(); + } + + let master_client = sentinel + .master_for(master_name, Some(&node_conn_info)) + .unwrap(); + let mut master_con = master_client.get_connection().unwrap(); + + assert_is_connection_to_master(&mut master_con); +} + #[test] fn test_sentinel_connect_to_random_replica() { + let number_of_replicas = 3; let master_name = "master1"; - let mut context = TestSentinelContext::new(2, 3, 3); - let node_conn_info: SentinelNodeConnectionInfo = context.sentinel_node_connection_info(); - let sentinel = context.sentinel_mut(); + let mut cluster = TestSentinelContext::new(2, number_of_replicas, 3); + let node_conn_info = cluster.sentinel_node_connection_info(); + let sentinel = cluster.sentinel_mut(); let master_client = sentinel .master_for(master_name, Some(&node_conn_info)) .unwrap(); let mut master_con = master_client.get_connection().unwrap(); + assert_is_connection_to_master(&mut master_con); + let mut replica_con = sentinel .replica_for(master_name, Some(&node_conn_info)) .unwrap() .get_connection() .unwrap(); - 
assert_is_connection_to_master(&mut master_con); assert_connection_is_replica_of_correct_master(&mut replica_con, &master_client); } @@ -416,8 +465,9 @@ pub mod async_tests { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_sentinel_connect_to_random_replica_async(#[case] runtime: RuntimeType) { let master_name = "master1"; let mut context = TestSentinelContext::new(2, 3, 3); @@ -452,8 +502,9 @@ pub mod async_tests { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_sentinel_connect_to_multiple_replicas_async(#[case] runtime: RuntimeType) { let number_of_replicas = 3; let master_name = "master1"; @@ -497,8 +548,9 @@ pub mod async_tests { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_sentinel_server_down_async(#[case] runtime: RuntimeType) { let number_of_replicas = 3; let master_name = "master1"; @@ -548,8 +600,9 @@ pub mod async_tests { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_sentinel_client_async(#[case] runtime: RuntimeType) { let master_name = "master1"; let mut context = TestSentinelContext::new(2, 3, 3); @@ -600,8 +653,9 @@ pub mod async_tests { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_sentinel_client_async_with_connection_timeout(#[case] runtime: RuntimeType) { let master_name = "master1"; let mut context = TestSentinelContext::new(2, 3, 3); @@ -659,8 +713,9 @@ pub mod async_tests { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_sentinel_client_async_with_response_timeout(#[case] runtime: RuntimeType) { let master_name = "master1"; let mut context = TestSentinelContext::new(2, 3, 3); @@ -718,8 +773,9 @@ pub mod async_tests { } #[rstest] - #[case::tokio(RuntimeType::Tokio)] + #[cfg_attr(feature = "tokio-comp", case::tokio(RuntimeType::Tokio))] #[cfg_attr(feature = "async-std-comp", case::async_std(RuntimeType::AsyncStd))] + #[cfg_attr(feature = "smol-comp", case::smol(RuntimeType::Smol))] fn test_sentinel_client_async_with_timeouts(#[case] runtime: RuntimeType) { let master_name = "master1"; let mut context = TestSentinelContext::new(2, 3, 3); diff --git a/redis/tests/test_types.rs b/redis/tests/test_types.rs index 81088f49d..3deaa8d62 100644 --- a/redis/tests/test_types.rs +++ b/redis/tests/test_types.rs @@ -1,7 +1,11 @@ mod 
support; mod types { - use std::{rc::Rc, sync::Arc}; + use std::{ + collections::{HashMap, HashSet}, + rc::Rc, + sync::Arc, + }; use redis::{ErrorKind, FromRedisValue, RedisError, RedisResult, ToRedisArgs, Value}; @@ -845,4 +849,122 @@ mod types { ) } } + + #[test] + fn test_complex_nested_tuples() { + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let value = Value::Array(vec![ + Value::Array(vec![ + Value::BulkString(b"Hi1".to_vec()), + Value::BulkString(b"Bye1".to_vec()), + Value::BulkString(b"Hi2".to_vec()), + Value::BulkString(b"Bye2".to_vec()), + ]), + Value::Array(vec![ + Value::BulkString(b"S1".to_vec()), + Value::BulkString(b"S2".to_vec()), + Value::BulkString(b"S3".to_vec()), + ]), + Value::Array(vec![ + Value::BulkString(b"Hi3".to_vec()), + Value::BulkString(b"Bye3".to_vec()), + Value::BulkString(b"Hi4".to_vec()), + Value::BulkString(b"Bye4".to_vec()), + ]), + Value::Array(vec![ + Value::BulkString(b"S4".to_vec()), + Value::BulkString(b"S5".to_vec()), + Value::BulkString(b"S6".to_vec()), + ]), + ]); + let res: Vec<(HashMap<String, String>, Vec<String>)> = + parse_mode.parse_redis_value(value).unwrap(); + + let mut expected_map1 = HashMap::new(); + expected_map1.insert("Hi1".to_string(), "Bye1".to_string()); + expected_map1.insert("Hi2".to_string(), "Bye2".to_string()); + + let mut expected_map2 = HashMap::new(); + expected_map2.insert("Hi3".to_string(), "Bye3".to_string()); + expected_map2.insert("Hi4".to_string(), "Bye4".to_string()); + + assert_eq!( + res, + vec![ + ( + expected_map1, + vec!["S1".to_string(), "S2".to_string(), "S3".to_string()] + ), + ( + expected_map2, + vec!["S4".to_string(), "S5".to_string(), "S6".to_string()] + ) + ] + ); + } + } + + #[test] + fn test_complex_nested_tuples_resp3() { + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let value = Value::Array(vec![ + Value::Map(vec![ + ( + Value::BulkString(b"Hi1".to_vec()), + Value::BulkString(b"Bye1".to_vec()), + ), + ( + Value::BulkString(b"Hi2".to_vec()), + Value::BulkString(b"Bye2".to_vec()), + ), + ]), + Value::Set(vec![ + Value::BulkString(b"S1".to_vec()), + Value::BulkString(b"S2".to_vec()), + Value::BulkString(b"S3".to_vec()), + ]), + Value::Map(vec![ + ( + Value::BulkString(b"Hi3".to_vec()), + Value::BulkString(b"Bye3".to_vec()), + ), + ( + Value::BulkString(b"Hi4".to_vec()), + Value::BulkString(b"Bye4".to_vec()), + ), + ]), + Value::Set(vec![ + Value::BulkString(b"S4".to_vec()), + Value::BulkString(b"S5".to_vec()), + Value::BulkString(b"S6".to_vec()), + ]), + ]); + let res: Vec<(HashMap<String, String>, HashSet<String>)> = + parse_mode.parse_redis_value(value).unwrap(); + + let mut expected_map1 = HashMap::new(); + expected_map1.insert("Hi1".to_string(), "Bye1".to_string()); + expected_map1.insert("Hi2".to_string(), "Bye2".to_string()); + let mut expected_set1 = HashSet::new(); + expected_set1.insert("S1".to_string()); + expected_set1.insert("S2".to_string()); + expected_set1.insert("S3".to_string()); + + let mut expected_map2 = HashMap::new(); + expected_map2.insert("Hi3".to_string(), "Bye3".to_string()); + expected_map2.insert("Hi4".to_string(), "Bye4".to_string()); + let mut expected_set2 = HashSet::new(); + expected_set2.insert("S4".to_string()); + expected_set2.insert("S5".to_string()); + expected_set2.insert("S6".to_string()); + + assert_eq!( + res, + vec![ + (expected_map1, expected_set1), + (expected_map2, expected_set2) + ] + ); + } + } }
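For completeness, the nested decoding these two tests pin down is available to callers through the public from_redis_value entry point; a minimal sketch, assuming a Value shaped like the arrays built above:

    use std::collections::{HashMap, HashSet};
    use redis::{from_redis_value, RedisResult, Value};

    // Decodes an array that alternates RESP3 maps and sets into typed pairs,
    // mirroring the expected results asserted in the tests above.
    fn decode_pairs(
        value: &Value,
    ) -> RedisResult<Vec<(HashMap<String, String>, HashSet<String>)>> {
        from_redis_value(value)
    }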