diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 48d93d723..b846e9f9b 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -36,6 +36,14 @@ jobs: /usr/bin/redis-server key: ${{ runner.os }}-redis + - name: Cache RedisJSON + id: cache-redisjson + uses: actions/cache@v2 + with: + path: | + /tmp/librejson.so + key: ${{ runner.os }}-redisjson + - name: Install redis if: steps.cache-redis.outputs.cache-hit != 'true' run: | @@ -55,7 +63,11 @@ jobs: - uses: Swatinem/rust-cache@v1 - uses: actions/checkout@v2 + - name: Run tests + run: make test + - name: Checkout RedisJSON + if: steps.cache-redisjson.outputs.cache-hit != 'true' uses: actions/checkout@v2 with: repository: "RedisJSON/RedisJSON" @@ -76,6 +88,7 @@ jobs: # This shouldn't cause issues in the future so long as no profiles or patches # are applied to the workspace Cargo.toml file - name: Compile RedisJSON + if: steps.cache-redisjson.outputs.cache-hit != 'true' run: | cp ./Cargo.toml ./Cargo.toml.actual echo $'\nexclude = [\"./__ci/redis-json\"]' >> Cargo.toml @@ -84,8 +97,8 @@ jobs: rm ./Cargo.toml; mv ./Cargo.toml.actual ./Cargo.toml rm -rf ./__ci/redis-json - - name: Run tests - run: make test + - name: Run module-specific tests + run: make test-module - name: Check features run: | diff --git a/LICENSE b/LICENSE index 13e2e6edb..533ac4e5a 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2013-2019 by Armin Ronacher, Jan-Erik Rediger. +Copyright (c) 2022 by redis-rs contributors Redis cluster code in parts copyright (c) 2018 by Atsushi Koge. 
diff --git a/Makefile b/Makefile index f417120ea..b8cc74786 100644 --- a/Makefile +++ b/Makefile @@ -2,30 +2,58 @@ build: @cargo build test: + @echo "====================================================================" @echo "Testing Connection Type TCP without features" @echo "====================================================================" - @REDISRS_SERVER_TYPE=tcp cargo test --no-default-features --tests -- --nocapture --test-threads=1 + @REDISRS_SERVER_TYPE=tcp cargo test -p redis --no-default-features -- --nocapture --test-threads=1 @echo "====================================================================" @echo "Testing Connection Type TCP with all features" @echo "====================================================================" - @REDISRS_SERVER_TYPE=tcp cargo test --all-features -- --nocapture --test-threads=1 + @REDISRS_SERVER_TYPE=tcp cargo test -p redis --all-features -- --nocapture --test-threads=1 --skip test_module @echo "====================================================================" - @echo "Testing Connection Type TCP with all features and TLS support" + @echo "Testing Connection Type TCP with all features and Rustls support" @echo "====================================================================" - @REDISRS_SERVER_TYPE=tcp+tls cargo test --all-features -- --nocapture --test-threads=1 + @REDISRS_SERVER_TYPE=tcp+tls cargo test -p redis --all-features -- --nocapture --test-threads=1 --skip test_module + + @echo "====================================================================" + @echo "Testing Connection Type TCP with all features and native-TLS support" + @echo "====================================================================" + @REDISRS_SERVER_TYPE=tcp+tls cargo test -p redis --features=json,tokio-native-tls-comp,connection-manager,cluster-async -- --nocapture --test-threads=1 --skip test_module @echo "====================================================================" @echo "Testing Connection Type UNIX" @echo 
"====================================================================" - @REDISRS_SERVER_TYPE=unix cargo test --test parser --test test_basic --test test_types --all-features -- --test-threads=1 + @REDISRS_SERVER_TYPE=unix cargo test -p redis --test parser --test test_basic --test test_types --all-features -- --test-threads=1 --skip test_module @echo "====================================================================" @echo "Testing Connection Type UNIX SOCKETS" @echo "====================================================================" - @REDISRS_SERVER_TYPE=unix cargo test --all-features -- --skip test_cluster + @REDISRS_SERVER_TYPE=unix cargo test -p redis --all-features -- --skip test_cluster --skip test_async_cluster --skip test_module + + @echo "====================================================================" + @echo "Testing async-std with Rustls" + @echo "====================================================================" + @REDISRS_SERVER_TYPE=tcp cargo test -p redis --features=async-std-rustls-comp,cluster-async -- --nocapture --test-threads=1 + + @echo "====================================================================" + @echo "Testing async-std with native-TLS" + @echo "====================================================================" + @REDISRS_SERVER_TYPE=tcp cargo test -p redis --features=async-std-native-tls-comp,cluster-async -- --nocapture --test-threads=1 + + @echo "====================================================================" + @echo "Testing redis-test" + @echo "====================================================================" + @cargo test -p redis-test + + +test-module: + @echo "====================================================================" + @echo "Testing with module support enabled (currently only RedisJSON)" + @echo "====================================================================" + @REDISRS_SERVER_TYPE=tcp cargo test --all-features test_module test-single: test diff --git a/README.md b/README.md 
index 7454f3355..8ec283ae6 100644 --- a/README.md +++ b/README.md @@ -14,13 +14,13 @@ The crate is called `redis` and you can depend on it via cargo: ```ini [dependencies] -redis = "0.22.0" +redis = "0.23.0" ``` Documentation on the library can be found at [docs.rs/redis](https://docs.rs/redis). -**Note: redis-rs requires at least Rust 1.51.** +**Note: redis-rs requires at least Rust 1.59.** ## Basic Operation @@ -50,52 +50,76 @@ fn fetch_an_integer() -> redis::RedisResult { ## Async support -To enable asynchronous clients a feature for the underlying feature need to be activated. +To enable asynchronous clients, enable the relevant feature in your Cargo.toml, +`tokio-comp` for tokio users or `async-std-comp` for async-std users. + ``` # if you use tokio -redis = { version = "0.22.0", features = ["tokio-comp"] } +redis = { version = "0.23.0", features = ["tokio-comp"] } # if you use async-std -redis = { version = "0.22.0", features = ["async-std-comp"] } +redis = { version = "0.23.0", features = ["async-std-comp"] } ``` ## TLS Support To enable TLS support, you need to use the relevant feature entry in your Cargo.toml. +Currently, `native-tls` and `rustls` are supported. 
+ +To use `native-tls`: + +``` +redis = { version = "0.23.0", features = ["tls-native-tls"] } + +# if you use tokio +redis = { version = "0.23.0", features = ["tokio-native-tls-comp"] } + +# if you use async-std +redis = { version = "0.23.0", features = ["async-std-native-tls-comp"] } +``` + +To use `rustls`: ``` -redis = { version = "0.22.0", features = ["tls"] } +redis = { version = "0.23.0", features = ["tls-rustls"] } # if you use tokio -redis = { version = "0.22.0", features = ["tokio-native-tls-comp"] } +redis = { version = "0.23.0", features = ["tokio-rustls-comp"] } # if you use async-std -redis = { version = "0.22.0", features = ["async-std-tls-comp"] } +redis = { version = "0.23.0", features = ["async-std-rustls-comp"] } ``` +With `rustls`, you can add the following feature flags on top of other feature flags to enable additional features: +- `tls-rustls-insecure`: Allow insecure TLS connections +- `tls-rustls-webpki-roots`: Use `webpki-roots` (Mozilla's root certificates) instead of native root certificates + then you should be able to connect to a redis instance using the `rediss://` URL scheme: ```rust let client = redis::Client::open("rediss://127.0.0.1/")?; ``` +**Deprecation Notice:** If you were using the `tls` or `async-std-tls-comp` features, please use the `tls-native-tls` or `async-std-native-tls-comp` features respectively. + ## Cluster Support -Cluster mode can be used by specifying "cluster" as a features entry in your Cargo.toml. +Support for Redis Cluster can be enabled by enabling the `cluster` feature in your Cargo.toml: -`redis = { version = "0.22.0", features = [ "cluster"] }` +`redis = { version = "0.23.0", features = [ "cluster"] }` -Then you can simply use the `ClusterClient` which accepts a list of available nodes. +Then you can simply use the `ClusterClient`, which accepts a list of available nodes. Note +that only one node in the cluster needs to be specified when instantiating the client, though +you can specify multiple. 
```rust use redis::cluster::ClusterClient; use redis::Commands; fn fetch_an_integer() -> String { - // connect to redis let nodes = vec!["redis://127.0.0.1/"]; - let client = ClusterClient::open(nodes).unwrap(); + let client = ClusterClient::new(nodes).unwrap(); let mut connection = client.get_connection().unwrap(); let _: () = connection.set("test", "test_data").unwrap(); let rv: String = connection.get("test").unwrap(); @@ -103,11 +127,30 @@ fn fetch_an_integer() -> String { } ``` +Async Redis Cluster support can be enabled by enabling the `cluster-async` feature, along +with your preferred async runtime, e.g.: + +`redis = { version = "0.23.0", features = [ "cluster-async", "tokio-comp" ] }` + +```rust +use redis::cluster::ClusterClient; +use redis::AsyncCommands; + +async fn fetch_an_integer() -> String { + let nodes = vec!["redis://127.0.0.1/"]; + let client = ClusterClient::new(nodes).unwrap(); + let mut connection = client.get_async_connection().await.unwrap(); + let _: () = connection.set("test", "test_data").await.unwrap(); + let rv: String = connection.get("test").await.unwrap(); + return rv; +} +``` + ## JSON Support Support for the RedisJSON Module can be enabled by specifying "json" as a feature in your Cargo.toml. 
-`redis = { version = "0.22.0", features = ["json"] }` +`redis = { version = "0.23.0", features = ["json"] }` Then you can simply import the `JsonCommands` trait which will add the `json` commands to all Redis Connections (not to be confused with just `Commands` which only adds the default commands) @@ -135,7 +178,7 @@ fn set_json_bool(key: P, path: P, b: bool) -> RedisResult ## Development To test `redis` you're going to need to be able to test with the Redis Modules, to do this -you must set the following envornment variables before running the test script +you must set the following environment variables before running the test script - `REDIS_RS_REDIS_JSON_PATH` = The absolute path to the RedisJSON module (Usually called `librejson.so`). diff --git a/redis-test/CHANGELOG.md b/redis-test/CHANGELOG.md index 9f370cd83..76ab12c4d 100644 --- a/redis-test/CHANGELOG.md +++ b/redis-test/CHANGELOG.md @@ -1,3 +1,21 @@ + +### 0.2.0 (2023-04-05) + +* Track redis 0.23.0 release + + +### 0.2.0-beta.1 (2023-03-28) + +* Track redis 0.23.0-beta.1 release + + +### 0.1.1 (2022-10-18) + +#### Changes +* Add README +* Update LICENSE file / symlink from parent directory + + ### 0.1.0 (2022-10-05) diff --git a/redis-test/Cargo.toml b/redis-test/Cargo.toml index 5ee04d724..cf094d160 100644 --- a/redis-test/Cargo.toml +++ b/redis-test/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "redis-test" -version = "0.1.0" +version = "0.2.0" edition = "2021" description = "Testing helpers for the `redis` crate" homepage = "https://github.com/redis-rs/redis-rs" @@ -10,7 +10,7 @@ license = "BSD-3-Clause" rust-version = "1.59" [dependencies] -redis = { version = "0.22.0", path = "../redis" } +redis = { version = "0.23.0", path = "../redis" } bytes = { version = "1", optional = true } futures = { version = "0.3", optional = true } @@ -19,6 +19,6 @@ futures = { version = "0.3", optional = true } aio = ["futures", "redis/aio"] [dev-dependencies] -redis = { version = "0.22.0", path = "../redis", features = 
["aio", "tokio-comp"] } +redis = { version = "0.23.0", path = "../redis", features = ["aio", "tokio-comp"] } tokio = { version = "1", features = ["rt", "macros", "rt-multi-thread"] } diff --git a/redis-test/LICENSE b/redis-test/LICENSE new file mode 120000 index 000000000..ea5b60640 --- /dev/null +++ b/redis-test/LICENSE @@ -0,0 +1 @@ +../LICENSE \ No newline at end of file diff --git a/redis-test/README.md b/redis-test/README.md new file mode 100644 index 000000000..b89bfc4ed --- /dev/null +++ b/redis-test/README.md @@ -0,0 +1,4 @@ +# redis-test + +Testing utilities for the redis-rs crate. + diff --git a/redis-test/src/lib.rs b/redis-test/src/lib.rs index 180648fe3..49094e426 100644 --- a/redis-test/src/lib.rs +++ b/redis-test/src/lib.rs @@ -397,4 +397,27 @@ mod tests { .expect("success"); assert_eq!(results, vec!["hello", "world"]); } + + #[test] + fn pipeline_atomic_test() { + let mut conn = MockRedisConnection::new(vec![MockCmd::with_values( + pipe().atomic().cmd("GET").arg("foo").cmd("GET").arg("bar"), + Ok(vec![Value::Bulk( + vec!["hello", "world"] + .into_iter() + .map(|x| Value::Data(x.as_bytes().into())) + .collect(), + )]), + )]); + + let results: Vec = pipe() + .atomic() + .cmd("GET") + .arg("foo") + .cmd("GET") + .arg("bar") + .query(&mut conn) + .expect("success"); + assert_eq!(results, vec!["hello", "world"]); + } } diff --git a/redis/CHANGELOG.md b/redis/CHANGELOG.md index d7cf9d81d..f6c46389e 100644 --- a/redis/CHANGELOG.md +++ b/redis/CHANGELOG.md @@ -1,3 +1,95 @@ + +### 0.23.0 (2023-04-05) +In addition to *everything mentioned in 0.23.0-beta.1 notes*, this release adds support for Rustls, a long- +sought feature. Thanks to @rharish101 and @LeoRowan for getting this in! 
+ +#### Changes +* Update Rustls to v0.21.0 ([#820](https://github.com/redis-rs/redis-rs/pull/820) @rharish101) +* Implement support for Rustls ([#725](https://github.com/redis-rs/redis-rs/pull/725) @rharish101, @LeoRowan) + + ### 0.23.0-beta.1 (2023-03-28) + +This release adds the `cluster_async` module, which introduces async Redis Cluster support. The code therein +is largely taken from @Marwes's [redis-cluster-async crate](https://github.com/redis-rs/redis-cluster-async), which itself +appears to have started from a sync Redis Cluster implementation started by @atuk721. In any case, thanks to @Marwes and @atuk721 +for the great work, and we hope to keep development moving forward in `redis-rs`. + +Though async Redis Cluster functionality for the time being has been kept as close to the originating crate as possible, previous users of +`redis-cluster-async` should note the following changes: +* Retries, while still configurable, can no longer be set to `None`/infinite retries +* Routing and slot parsing logic has been removed and merged with existing `redis-rs` functionality +* The client has been removed and superseded by common `ClusterClient` +* Renamed `Connection` to `ClusterConnection` +* Added support for reading from replicas +* Added support for insecure TLS +* Added support for setting both username and password + +#### Breaking Changes +* Fix long-standing bug related to `AsyncIter`'s stream implementation in which polling the server + for additional data yielded broken data in most cases. Type bounds for `AsyncIter` have changed slightly, + making this a potentially breaking change. 
([#597](https://github.com/redis-rs/redis-rs/pull/597) @roger) + +#### Changes +* Commands: Add additional generic args for key arguments ([#795](https://github.com/redis-rs/redis-rs/pull/795) @MaxOhn) +* Add `mset` / deprecate `set_multiple` ([#766](https://github.com/redis-rs/redis-rs/pull/766) @randomairborne) +* More efficient interfaces for `MultiplexedConnection` and `ConnectionManager` ([#811](https://github.com/redis-rs/redis-rs/pull/811) @nihohit) +* Refactor / remove flaky test ([#810](https://github.com/redis-rs/redis-rs/pull/810)) +* `cluster_async`: rename `Connection` to `ClusterConnection`, `Pipeline` to `ClusterConnInner` ([#808](https://github.com/redis-rs/redis-rs/pull/808)) +* Support parsing IPV6 cluster nodes ([#796](https://github.com/redis-rs/redis-rs/pull/796) @socs) +* Common client for sync/async cluster connections ([#798](https://github.com/redis-rs/redis-rs/pull/798)) + * `cluster::ClusterConnection` underlying connection type is now generic (with existing type as default) + * Support `read_from_replicas` in cluster_async + * Set retries in `ClusterClientBuilder` + * Add mock tests for `cluster` +* cluster-async common slot parsing([#793](https://github.com/redis-rs/redis-rs/pull/793)) +* Support async-std in cluster_async module ([#790](https://github.com/redis-rs/redis-rs/pull/790)) +* Async-Cluster use same routing as Sync-Cluster ([#789](https://github.com/redis-rs/redis-rs/pull/789)) +* Add Async Cluster Support ([#696](https://github.com/redis-rs/redis-rs/pull/696)) +* Fix broken json-module tests ([#786](https://github.com/redis-rs/redis-rs/pull/786)) +* `cluster`: Tls Builder support / simplify cluster connection map ([#718](https://github.com/redis-rs/redis-rs/pull/718) @0xWOF, @utkarshgupta137) + + +### 0.22.3 (2023-01-23) + +#### Changes +* Restore inherent `ClusterConnection::check_connection()` method ([#758](https://github.com/redis-rs/redis-rs/pull/758) @robjtede) + + + +### 0.22.2 (2023-01-07) + +This release adds 
various incremental improvements and fixes a few long-standing bugs. Thanks to all our +contributors for making this release happen. + +#### Features +* Implement ToRedisArgs for HashMap ([#722](https://github.com/redis-rs/redis-rs/pull/722) @gibranamparan) +* Add explicit `MGET` command ([#729](https://github.com/redis-rs/redis-rs/pull/729) @vamshiaruru-virgodesigns) + +#### Bug fixes +* Enable single-item-vector `get` responses ([#507](https://github.com/redis-rs/redis-rs/pull/507) @hank121314) +* Fix empty result from xread_options with deleted entries ([#712](https://github.com/redis-rs/redis-rs/pull/712) @Quiwin) +* Limit Parser Recursion ([#724](https://github.com/redis-rs/redis-rs/pull/724)) +* Improve MultiplexedConnection Error Handling ([#699](https://github.com/redis-rs/redis-rs/pull/699)) + +#### Changes +* Add test case for atomic pipeline ([#702](https://github.com/redis-rs/redis-rs/pull/702) @CNLHC) +* Capture subscribe result error in PubSub doc example ([#739](https://github.com/redis-rs/redis-rs/pull/739) @baoyachi) +* Use async-std name resolution when necessary ([#701](https://github.com/redis-rs/redis-rs/pull/701) @UgnilJoZ) +* Add Script::invoke_async method ([#711](https://github.com/redis-rs/redis-rs/pull/711) @r-bk) +* Cluster Refactorings ([#717](https://github.com/redis-rs/redis-rs/pull/717), [#716](https://github.com/redis-rs/redis-rs/pull/716), [#709](https://github.com/redis-rs/redis-rs/pull/709), [#707](https://github.com/redis-rs/redis-rs/pull/707), [#706](https://github.com/redis-rs/redis-rs/pull/706) @0xWOF, @utkarshgupta137) +* Fix intermittent test failure ([#714](https://github.com/redis-rs/redis-rs/pull/714) @0xWOF, @utkarshgupta137) +* Doc changes ([#705](https://github.com/redis-rs/redis-rs/pull/705) @0xWOF, @utkarshgupta137) +* Lint fixes ([#704](https://github.com/redis-rs/redis-rs/pull/704) @0xWOF) + + + ### 0.22.1 (2022-10-18) + +#### Changes +* Add README attribute to Cargo.toml +* Update LICENSE file / symlink from 
parent directory + ### 0.22.0 (2022-10-05) diff --git a/redis/Cargo.toml b/redis/Cargo.toml index 4989ea0c4..c9d33dae9 100644 --- a/redis/Cargo.toml +++ b/redis/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "redis" -version = "0.22.0" +version = "0.23.0" keywords = ["redis", "database"] description = "Redis driver for Rust." homepage = "https://github.com/redis-rs/redis-rs" @@ -9,6 +9,7 @@ documentation = "https://docs.rs/redis" license = "BSD-3-Clause" edition = "2021" rust-version = "1.59" +readme = "../README.md" [package.metadata.docs.rs] all-features = true @@ -53,11 +54,18 @@ rand = { version = "0.8", optional = true } async-std = { version = "1.8.0", optional = true} async-trait = { version = "0.1.24", optional = true } -# Only needed for TLS +# Only needed for native tls native-tls = { version = "0.2", optional = true } tokio-native-tls = { version = "0.3", optional = true } async-native-tls = { version = "0.4", optional = true } +# Only needed for rustls +rustls = { version = "0.21.0", optional = true } +webpki-roots = { version = "0.23.0", optional = true } +rustls-native-certs = { version = "0.6.2", optional = true } +tokio-rustls = { version = "0.24.0", optional = true } +futures-rustls = { version = "0.24.0", optional = true } + # Only needed for RedisJSON Support serde = { version = "1.0.82", optional = true } serde_json = { version = "1.0.82", optional = true } @@ -65,22 +73,33 @@ serde_json = { version = "1.0.82", optional = true } # Optional aHash support ahash = { version = "0.7.6", optional = true } +log = { version = "0.4", optional = true } + [features] default = ["acl", "streams", "geospatial", "script"] acl = [] aio = ["bytes", "pin-project-lite", "futures-util", "futures-util/alloc", "futures-util/sink", "tokio/io-util", "tokio-util", "tokio-util/codec", "tokio/sync", "combine/tokio", "async-trait"] geospatial = [] -json = ["serde", "serde_json"] +json = ["serde", "serde/derive", "serde_json"] cluster = ["crc16", "rand"] script = ["sha1_smol"] 
-tls = ["native-tls"] +tls-native-tls = ["native-tls"] +tls-rustls = ["rustls", "rustls-native-certs"] +tls-rustls-insecure = ["tls-rustls", "rustls/dangerous_configuration"] +tls-rustls-webpki-roots = ["tls-rustls", "webpki-roots"] async-std-comp = ["aio", "async-std"] -async-std-tls-comp = ["async-std-comp", "async-native-tls", "tls"] +async-std-native-tls-comp = ["async-std-comp", "async-native-tls", "tls-native-tls"] +async-std-rustls-comp = ["async-std-comp", "futures-rustls", "tls-rustls"] tokio-comp = ["aio", "tokio", "tokio/net"] -tokio-native-tls-comp = ["tls", "tokio-native-tls"] +tokio-native-tls-comp = ["tokio-comp", "tls-native-tls", "tokio-native-tls"] +tokio-rustls-comp = ["tokio-comp", "tls-rustls", "tokio-rustls"] connection-manager = ["arc-swap", "futures", "aio"] streams = [] +cluster-async = ["cluster", "futures", "futures-util", "log"] +# Deprecated features +tls = ["tls-native-tls"] # use "tls-native-tls" instead +async-std-tls-comp = ["async-std-native-tls-comp"] # use "async-std-native-tls-comp" instead [dev-dependencies] rand = "0.8" @@ -88,11 +107,13 @@ socket2 = "0.4" assert_approx_eq = "1.0" fnv = "1.0.5" futures = "0.3" -criterion = "0.3" +criterion = "0.4" partial-io = { version = "0.5", features = ["tokio", "quickcheck1"] } quickcheck = "1.0.3" tokio = { version = "1", features = ["rt", "macros", "rt-multi-thread", "time"] } tempfile = "3.2" +once_cell = "1" +anyhow = "1" [[test]] name = "test_async" @@ -110,9 +131,13 @@ required-features = ["aio"] name = "test_acl" [[test]] -name = "test_json" +name = "test_module_json" required-features = ["json", "serde/derive"] +[[test]] +name = "test_cluster_async" +required-features = ["cluster-async"] + [[bench]] name = "bench_basic" harness = false diff --git a/redis/LICENSE b/redis/LICENSE new file mode 120000 index 000000000..ea5b60640 --- /dev/null +++ b/redis/LICENSE @@ -0,0 +1 @@ +../LICENSE \ No newline at end of file diff --git a/redis/benches/bench_basic.rs 
b/redis/benches/bench_basic.rs index 1ecbeb06e..946c76ba4 100644 --- a/redis/benches/bench_basic.rs +++ b/redis/benches/bench_basic.rs @@ -93,7 +93,7 @@ fn long_pipeline() -> redis::Pipeline { let mut pipe = redis::pipe(); for i in 0..PIPELINE_QUERIES { - pipe.set(format!("foo{}", i), "bar").ignore(); + pipe.set(format!("foo{i}"), "bar").ignore(); } pipe } @@ -147,7 +147,7 @@ fn bench_multiplexed_async_implicit_pipeline(b: &mut Bencher) { .unwrap(); let cmds: Vec<_> = (0..PIPELINE_QUERIES) - .map(|i| redis::cmd("SET").arg(format!("foo{}", i)).arg(i).clone()) + .map(|i| redis::cmd("SET").arg(format!("foo{i}")).arg(i).clone()) .collect(); let mut connections = (0..PIPELINE_QUERIES) diff --git a/redis/benches/bench_cluster.rs b/redis/benches/bench_cluster.rs index 9717f8366..b9c1280dd 100644 --- a/redis/benches/bench_cluster.rs +++ b/redis/benches/bench_cluster.rs @@ -46,7 +46,7 @@ fn bench_pipeline(c: &mut Criterion, con: &mut redis::cluster::ClusterConnection let mut queries = Vec::new(); for i in 0..PIPELINE_QUERIES { - queries.push(format!("foo{}", i)); + queries.push(format!("foo{i}")); } let build_pipeline = || { diff --git a/redis/examples/async-connection-loss.rs b/redis/examples/async-connection-loss.rs index 45079c359..b84b5d319 100644 --- a/redis/examples/async-connection-loss.rs +++ b/redis/examples/async-connection-loss.rs @@ -28,7 +28,7 @@ async fn run_single(mut con: C) -> RedisResult<()> { println!(); println!("> PING"); let result: RedisResult = redis::cmd("PING").query_async(&mut con).await; - println!("< {:?}", result); + println!("< {result:?}"); } } diff --git a/redis/examples/async-multiplexed.rs b/redis/examples/async-multiplexed.rs index f6aea4114..6702fa722 100644 --- a/redis/examples/async-multiplexed.rs +++ b/redis/examples/async-multiplexed.rs @@ -4,9 +4,9 @@ use redis::{aio::MultiplexedConnection, RedisResult}; async fn test_cmd(con: &MultiplexedConnection, i: i32) -> RedisResult<()> { let mut con = con.clone(); - let key = 
format!("key{}", i); - let key2 = format!("key{}_2", i); - let value = format!("foo{}", i); + let key = format!("key{i}"); + let key2 = format!("key{i}_2"); + let value = format!("foo{i}"); redis::cmd("SET") .arg(&key[..]) diff --git a/redis/examples/basic.rs b/redis/examples/basic.rs index e621e2949..50ccbb6f5 100644 --- a/redis/examples/basic.rs +++ b/redis/examples/basic.rs @@ -64,7 +64,7 @@ fn do_show_scanning(con: &mut redis::Connection) -> redis::RedisResult<()> { // type of the iterator, rust will figure "int" out for us. let sum: i32 = cmd.iter::(con)?.sum(); - println!("The sum of all numbers in the set 0-1000: {}", sum); + println!("The sum of all numbers in the set 0-1000: {sum}"); Ok(()) } @@ -103,7 +103,7 @@ fn do_atomic_increment_lowlevel(con: &mut redis::Connection) -> redis::RedisResu } Some(response) => { let (new_val,) = response; - println!(" New value: {}", new_val); + println!(" New value: {new_val}"); break; } } @@ -129,7 +129,7 @@ fn do_atomic_increment(con: &mut redis::Connection) -> redis::RedisResult<()> { })?; // and print the result - println!("New value: {}", new_val); + println!("New value: {new_val}"); Ok(()) } diff --git a/redis/examples/geospatial.rs b/redis/examples/geospatial.rs index 58faab296..b2d408af3 100644 --- a/redis/examples/geospatial.rs +++ b/redis/examples/geospatial.rs @@ -27,12 +27,12 @@ fn run() -> RedisResult<()> { ], )?; - println!("[geo_add] Added {} members.", added); + println!("[geo_add] Added {added} members."); // Get the position of one of them. 
let position: Vec> = con.geo_pos("gis", "Palermo")?; - println!("[geo_pos] Position for Palermo: {:?}", position); + println!("[geo_pos] Position for Palermo: {position:?}"); // Search members near (13.5, 37.75) @@ -61,7 +61,7 @@ fn run() -> RedisResult<()> { fn main() { if let Err(e) = run() { - println!("{:?}", e); + println!("{e:?}"); exit(1); } } diff --git a/redis/examples/streams.rs b/redis/examples/streams.rs index e0b3a5b01..d22c0601e 100644 --- a/redis/examples/streams.rs +++ b/redis/examples/streams.rs @@ -77,7 +77,7 @@ fn demo_group_reads(client: &redis::Client) { for key in STREAMS { let created: Result<(), _> = con.xgroup_create_mkstream(*key, GROUP_NAME, "$"); if let Err(e) = created { - println!("Group already exists: {:?}", e) + println!("Group already exists: {e:?}") } } @@ -216,9 +216,9 @@ fn read_records(client: &redis::Client) -> RedisResult<()> { .expect("read"); for StreamKey { key, ids } in srr.keys { - println!("Stream {}", key); + println!("Stream {key}"); for StreamId { id, map } in ids { - println!("\tID {}", id); + println!("\tID {id}"); for (n, s) in map { if let Value::Data(bytes) = s { println!("\t\t{}: {}", n, String::from_utf8(bytes).expect("utf8")) @@ -233,7 +233,7 @@ fn read_records(client: &redis::Client) -> RedisResult<()> { } fn consumer_name(slowness: u8) -> String { - format!("example-consumer-{}", slowness) + format!("example-consumer-{slowness}") } const GROUP_NAME: &str = "example-group-aaa"; diff --git a/redis/src/acl.rs b/redis/src/acl.rs index 00f519586..2e2e984a7 100644 --- a/redis/src/acl.rs +++ b/redis/src/acl.rs @@ -81,21 +81,21 @@ impl ToRedisArgs for Rule { On => out.write_arg(b"on"), Off => out.write_arg(b"off"), - AddCommand(cmd) => out.write_arg_fmt(format_args!("+{}", cmd)), - RemoveCommand(cmd) => out.write_arg_fmt(format_args!("-{}", cmd)), - AddCategory(cat) => out.write_arg_fmt(format_args!("+@{}", cat)), - RemoveCategory(cat) => out.write_arg_fmt(format_args!("-@{}", cat)), + AddCommand(cmd) => 
out.write_arg_fmt(format_args!("+{cmd}")), + RemoveCommand(cmd) => out.write_arg_fmt(format_args!("-{cmd}")), + AddCategory(cat) => out.write_arg_fmt(format_args!("+@{cat}")), + RemoveCategory(cat) => out.write_arg_fmt(format_args!("-@{cat}")), AllCommands => out.write_arg(b"allcommands"), NoCommands => out.write_arg(b"nocommands"), - AddPass(pass) => out.write_arg_fmt(format_args!(">{}", pass)), - RemovePass(pass) => out.write_arg_fmt(format_args!("<{}", pass)), - AddHashedPass(pass) => out.write_arg_fmt(format_args!("#{}", pass)), - RemoveHashedPass(pass) => out.write_arg_fmt(format_args!("!{}", pass)), + AddPass(pass) => out.write_arg_fmt(format_args!(">{pass}")), + RemovePass(pass) => out.write_arg_fmt(format_args!("<{pass}")), + AddHashedPass(pass) => out.write_arg_fmt(format_args!("#{pass}")), + RemoveHashedPass(pass) => out.write_arg_fmt(format_args!("!{pass}")), NoPass => out.write_arg(b"nopass"), ResetPass => out.write_arg(b"resetpass"), - Pattern(pat) => out.write_arg_fmt(format_args!("~{}", pat)), + Pattern(pat) => out.write_arg_fmt(format_args!("~{pat}")), AllKeys => out.write_arg(b"allkeys"), ResetKeys => out.write_arg(b"resetkeys"), diff --git a/redis/src/aio.rs b/redis/src/aio.rs index 1c6b63d64..6534e76ca 100644 --- a/redis/src/aio.rs +++ b/redis/src/aio.rs @@ -4,7 +4,6 @@ use std::collections::VecDeque; use std::fmt; use std::fmt::Debug; use std::io; -use std::mem; use std::net::SocketAddr; #[cfg(unix)] use std::path::Path; @@ -13,15 +12,13 @@ use std::task::{self, Poll}; use combine::{parser::combinator::AnySendSyncPartialState, stream::PointerOffset}; +#[cfg(feature = "tokio-comp")] +use ::tokio::net::lookup_host; use ::tokio::{ io::{AsyncRead, AsyncWrite, AsyncWriteExt}, - net::lookup_host, sync::{mpsc, oneshot}, }; -#[cfg(feature = "tls")] -use native_tls::TlsConnector; - #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] use tokio_util::codec::Decoder; @@ -47,6 +44,9 @@ use crate::{from_redis_value, ToRedisArgs}; 
#[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] pub mod async_std; +#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] +use ::async_std::net::ToSocketAddrs; + /// Enables the tokio compatibility #[cfg(feature = "tokio-comp")] #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] @@ -59,7 +59,7 @@ pub(crate) trait RedisRuntime: AsyncStream + Send + Sync + Sized + 'static { async fn connect_tcp(socket_addr: SocketAddr) -> RedisResult; // Performs a TCP TLS connection - #[cfg(feature = "tls")] + #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))] async fn connect_tcp_tls( hostname: &str, socket_addr: SocketAddr, @@ -185,7 +185,7 @@ where /// The message itself is still generic and can be converted into an appropriate type through /// the helper methods on it. /// This can be useful in cases where the stream needs to be returned or held by something other - // than the [`PubSub`]. + /// than the [`PubSub`]. pub fn into_on_message(self) -> impl Stream { ValueCodec::default() .framed(self.0.con) @@ -459,7 +459,7 @@ pub(crate) async fn connect_simple( ::connect_tcp(socket_addr).await? } - #[cfg(feature = "tls")] + #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))] ConnectionAddr::TcpTls { ref host, port, @@ -469,7 +469,7 @@ pub(crate) async fn connect_simple( ::connect_tcp_tls(host, socket_addr, insecure).await? } - #[cfg(not(feature = "tls"))] + #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))] ConnectionAddr::TcpTls { .. 
} => { fail!(( ErrorKind::InvalidClientConfig, @@ -492,7 +492,10 @@ pub(crate) async fn connect_simple( } async fn get_socket_addrs(host: &str, port: u16) -> RedisResult { + #[cfg(feature = "tokio-comp")] let mut socket_addrs = lookup_host((host, port)).await?; + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + let mut socket_addrs = (host, port).to_socket_addrs().await?; match socket_addrs.next() { Some(socket_addr) => Ok(socket_addr), None => Err(RedisError::from(( @@ -602,8 +605,22 @@ type PipelineOutput = oneshot::Sender, E>>; struct InFlight { output: PipelineOutput, - response_count: usize, + expected_response_count: usize, + current_response_count: usize, buffer: Vec, + first_err: Option, +} + +impl InFlight { + fn new(output: PipelineOutput, expected_response_count: usize) -> Self { + Self { + output, + expected_response_count, + current_response_count: 0, + buffer: Vec::new(), + first_err: None, + } + } } // A single message sent through the pipeline @@ -679,26 +696,37 @@ where fn send_result(self: Pin<&mut Self>, result: Result) { let self_ = self.project(); - let response = { + + { let entry = match self_.in_flight.front_mut() { Some(entry) => entry, None => return, }; + match result { Ok(item) => { entry.buffer.push(item); - if entry.response_count > entry.buffer.len() { - // Need to gather more response values - return; + } + Err(err) => { + if entry.first_err.is_none() { + entry.first_err = Some(err); } - Ok(mem::take(&mut entry.buffer)) } - // If we fail we must respond immediately - Err(err) => Err(err), } - }; + + entry.current_response_count += 1; + if entry.current_response_count < entry.expected_response_count { + // Need to gather more response values + return; + } + } let entry = self_.in_flight.pop_front().unwrap(); + let response = match entry.first_err { + Some(err) => Err(err), + None => Ok(entry.buffer), + }; + // `Err` means that the receiver was dropped in which case it does not // care about the output and we can 
continue by just dropping the value // and sender @@ -750,11 +778,9 @@ where match self_.sink_stream.start_send(input) { Ok(()) => { - self_.in_flight.push_back(InFlight { - output, - response_count, - buffer: Vec::new(), - }); + self_ + .in_flight + .push_back(InFlight::new(output, response_count)); Ok(()) } Err(err) => { @@ -913,23 +939,45 @@ impl MultiplexedConnection { }; Ok((con, driver)) } + + /// Sends an already encoded (packed) command into the TCP socket and + /// reads the single response from it. + pub async fn send_packed_command(&mut self, cmd: &Cmd) -> RedisResult { + let value = self + .pipeline + .send(cmd.get_packed_command()) + .await + .map_err(|err| { + err.unwrap_or_else(|| RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe))) + })?; + Ok(value) + } + + /// Sends multiple already encoded (packed) command into the TCP socket + /// and reads `count` responses from it. This is used to implement + /// pipelining. + pub async fn send_packed_commands( + &mut self, + cmd: &crate::Pipeline, + offset: usize, + count: usize, + ) -> RedisResult> { + let mut value = self + .pipeline + .send_recv_multiple(cmd.get_packed_pipeline(), offset + count) + .await + .map_err(|err| { + err.unwrap_or_else(|| RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe))) + })?; + + value.drain(..offset); + Ok(value) + } } impl ConnectionLike for MultiplexedConnection { fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> { - (async move { - let value = self - .pipeline - .send(cmd.get_packed_command()) - .await - .map_err(|err| { - err.unwrap_or_else(|| { - RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe)) - }) - })?; - Ok(value) - }) - .boxed() + (async move { self.send_packed_command(cmd).await }).boxed() } fn req_packed_commands<'a>( @@ -938,21 +986,7 @@ impl ConnectionLike for MultiplexedConnection { offset: usize, count: usize, ) -> RedisFuture<'a, Vec> { - (async move { - let mut value = self - .pipeline - 
.send_recv_multiple(cmd.get_packed_pipeline(), offset + count) - .await - .map_err(|err| { - err.unwrap_or_else(|| { - RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe)) - }) - })?; - - value.drain(..offset); - Ok(value) - }) - .boxed() + (async move { self.send_packed_commands(cmd, offset, count).await }).boxed() } fn get_db(&self) -> i64 { @@ -1016,6 +1050,30 @@ mod connection_manager { /// Type alias for a shared boxed future that will resolve to a `CloneableRedisResult`. type SharedRedisFuture = Shared>>; + /// Handle a command result. If the connection was dropped, reconnect. + macro_rules! reconnect_if_dropped { + ($self:expr, $result:expr, $current:expr) => { + if let Err(ref e) = $result { + if e.is_connection_dropped() { + $self.reconnect($current); + } + } + }; + } + + /// Handle a connection result. If there's an I/O error, reconnect. + /// Propagate any error. + macro_rules! reconnect_if_io_error { + ($self:expr, $result:expr, $current:expr) => { + if let Err(e) = $result { + if e.is_io_error() { + $self.reconnect($current); + } + return Err(e); + } + }; + } + impl ConnectionManager { /// Connect to the server and store the connection inside the returned `ConnectionManager`. /// @@ -1063,47 +1121,49 @@ mod connection_manager { self.runtime.spawn(new_connection.map(|_| ())); } } - } - /// Handle a command result. If the connection was dropped, reconnect. - macro_rules! reconnect_if_dropped { - ($self:expr, $result:expr, $current:expr) => { - if let Err(ref e) = $result { - if e.is_connection_dropped() { - $self.reconnect($current); - } - } - }; - } + /// Sends an already encoded (packed) command into the TCP socket and + /// reads the single response from it. 
+ pub async fn send_packed_command(&mut self, cmd: &Cmd) -> RedisResult { + // Clone connection to avoid having to lock the ArcSwap in write mode + let guard = self.connection.load(); + let connection_result = (**guard) + .clone() + .await + .map_err(|e| e.clone_mostly("Reconnecting failed")); + reconnect_if_io_error!(self, connection_result, guard); + let result = connection_result?.send_packed_command(cmd).await; + reconnect_if_dropped!(self, &result, guard); + result + } - /// Handle a connection result. If there's an I/O error, reconnect. - /// Propagate any error. - macro_rules! reconnect_if_io_error { - ($self:expr, $result:expr, $current:expr) => { - if let Err(e) = $result { - if e.is_io_error() { - $self.reconnect($current); - } - return Err(e); - } - }; + /// Sends multiple already encoded (packed) command into the TCP socket + /// and reads `count` responses from it. This is used to implement + /// pipelining. + pub async fn send_packed_commands( + &mut self, + cmd: &crate::Pipeline, + offset: usize, + count: usize, + ) -> RedisResult> { + // Clone shared connection future to avoid having to lock the ArcSwap in write mode + let guard = self.connection.load(); + let connection_result = (**guard) + .clone() + .await + .map_err(|e| e.clone_mostly("Reconnecting failed")); + reconnect_if_io_error!(self, connection_result, guard); + let result = connection_result? 
+ .send_packed_commands(cmd, offset, count) + .await; + reconnect_if_dropped!(self, &result, guard); + result + } } impl ConnectionLike for ConnectionManager { fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> { - (async move { - // Clone connection to avoid having to lock the ArcSwap in write mode - let guard = self.connection.load(); - let connection_result = (**guard) - .clone() - .await - .map_err(|e| e.clone_mostly("Reconnecting failed")); - reconnect_if_io_error!(self, connection_result, guard); - let result = connection_result?.req_packed_command(cmd).await; - reconnect_if_dropped!(self, &result, guard); - result - }) - .boxed() + (async move { self.send_packed_command(cmd).await }).boxed() } fn req_packed_commands<'a>( @@ -1112,21 +1172,7 @@ mod connection_manager { offset: usize, count: usize, ) -> RedisFuture<'a, Vec> { - (async move { - // Clone shared connection future to avoid having to lock the ArcSwap in write mode - let guard = self.connection.load(); - let connection_result = (**guard) - .clone() - .await - .map_err(|e| e.clone_mostly("Reconnecting failed")); - reconnect_if_io_error!(self, connection_result, guard); - let result = connection_result? 
- .req_packed_commands(cmd, offset, count) - .await; - reconnect_if_dropped!(self, &result, guard); - result - }) - .boxed() + (async move { self.send_packed_commands(cmd, offset, count).await }).boxed() } fn get_db(&self) -> i64 { diff --git a/redis/src/aio/async_std.rs b/redis/src/aio/async_std.rs index 7b5b272e5..5f949b15b 100644 --- a/redis/src/aio/async_std.rs +++ b/redis/src/aio/async_std.rs @@ -1,5 +1,7 @@ #[cfg(unix)] use std::path::Path; +#[cfg(feature = "tls-rustls")] +use std::sync::Arc; use std::{ future::Future, io, @@ -10,8 +12,15 @@ use std::{ use crate::aio::{AsyncStream, RedisRuntime}; use crate::types::RedisResult; -#[cfg(feature = "tls")] + +#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] use async_native_tls::{TlsConnector, TlsStream}; + +#[cfg(feature = "tls-rustls")] +use crate::connection::create_rustls_config; +#[cfg(feature = "tls-rustls")] +use futures_rustls::{client::TlsStream, TlsConnector}; + use async_std::net::TcpStream; #[cfg(unix)] use async_std::os::unix::net::UnixStream; @@ -82,7 +91,10 @@ pub enum AsyncStd { /// Represents an Async_std TCP connection. Tcp(AsyncStdWrapped), /// Represents an Async_std TLS encrypted TCP connection. - #[cfg(feature = "async-std-tls-comp")] + #[cfg(any( + feature = "async-std-native-tls-comp", + feature = "async-std-rustls-comp" + ))] TcpTls(AsyncStdWrapped>>), /// Represents an Async_std Unix connection. 
#[cfg(unix)] @@ -97,7 +109,10 @@ impl AsyncWrite for AsyncStd { ) -> Poll> { match &mut *self { AsyncStd::Tcp(r) => Pin::new(r).poll_write(cx, buf), - #[cfg(feature = "async-std-tls-comp")] + #[cfg(any( + feature = "async-std-native-tls-comp", + feature = "async-std-rustls-comp" + ))] AsyncStd::TcpTls(r) => Pin::new(r).poll_write(cx, buf), #[cfg(unix)] AsyncStd::Unix(r) => Pin::new(r).poll_write(cx, buf), @@ -107,7 +122,10 @@ impl AsyncWrite for AsyncStd { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll> { match &mut *self { AsyncStd::Tcp(r) => Pin::new(r).poll_flush(cx), - #[cfg(feature = "async-std-tls-comp")] + #[cfg(any( + feature = "async-std-native-tls-comp", + feature = "async-std-rustls-comp" + ))] AsyncStd::TcpTls(r) => Pin::new(r).poll_flush(cx), #[cfg(unix)] AsyncStd::Unix(r) => Pin::new(r).poll_flush(cx), @@ -117,7 +135,10 @@ impl AsyncWrite for AsyncStd { fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll> { match &mut *self { AsyncStd::Tcp(r) => Pin::new(r).poll_shutdown(cx), - #[cfg(feature = "async-std-tls-comp")] + #[cfg(any( + feature = "async-std-native-tls-comp", + feature = "async-std-rustls-comp" + ))] AsyncStd::TcpTls(r) => Pin::new(r).poll_shutdown(cx), #[cfg(unix)] AsyncStd::Unix(r) => Pin::new(r).poll_shutdown(cx), @@ -133,7 +154,10 @@ impl AsyncRead for AsyncStd { ) -> Poll> { match &mut *self { AsyncStd::Tcp(r) => Pin::new(r).poll_read(cx, buf), - #[cfg(feature = "async-std-tls-comp")] + #[cfg(any( + feature = "async-std-native-tls-comp", + feature = "async-std-rustls-comp" + ))] AsyncStd::TcpTls(r) => Pin::new(r).poll_read(cx, buf), #[cfg(unix)] AsyncStd::Unix(r) => Pin::new(r).poll_read(cx, buf), @@ -149,7 +173,7 @@ impl RedisRuntime for AsyncStd { .map(|con| Self::Tcp(AsyncStdWrapped::new(con)))?) 
} - #[cfg(feature = "tls")] + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] async fn connect_tcp_tls( hostname: &str, socket_addr: SocketAddr, @@ -170,6 +194,23 @@ impl RedisRuntime for AsyncStd { .map(|con| Self::TcpTls(AsyncStdWrapped::new(Box::new(con))))?) } + #[cfg(feature = "tls-rustls")] + async fn connect_tcp_tls( + hostname: &str, + socket_addr: SocketAddr, + insecure: bool, + ) -> RedisResult { + let tcp_stream = TcpStream::connect(&socket_addr).await?; + + let config = create_rustls_config(insecure)?; + let tls_connector = TlsConnector::from(Arc::new(config)); + + Ok(tls_connector + .connect(hostname.try_into()?, tcp_stream) + .await + .map(|con| Self::TcpTls(AsyncStdWrapped::new(Box::new(con))))?) + } + #[cfg(unix)] async fn connect_unix(path: &Path) -> RedisResult { Ok(UnixStream::connect(path) @@ -184,7 +225,10 @@ impl RedisRuntime for AsyncStd { fn boxed(self) -> Pin> { match self { AsyncStd::Tcp(x) => Box::pin(x), - #[cfg(feature = "async-std-tls-comp")] + #[cfg(any( + feature = "async-std-native-tls-comp", + feature = "async-std-rustls-comp" + ))] AsyncStd::TcpTls(x) => Box::pin(x), #[cfg(unix)] AsyncStd::Unix(x) => Box::pin(x), diff --git a/redis/src/aio/tokio.rs b/redis/src/aio/tokio.rs index 0e5afbd74..003bcc210 100644 --- a/redis/src/aio/tokio.rs +++ b/redis/src/aio/tokio.rs @@ -15,10 +15,17 @@ use tokio::{ net::TcpStream as TcpStreamTokio, }; -#[cfg(feature = "tls")] -use super::TlsConnector; +#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] +use native_tls::TlsConnector; -#[cfg(feature = "tokio-native-tls-comp")] +#[cfg(feature = "tls-rustls")] +use crate::connection::create_rustls_config; +#[cfg(feature = "tls-rustls")] +use std::{convert::TryInto, sync::Arc}; +#[cfg(feature = "tls-rustls")] +use tokio_rustls::{client::TlsStream, TlsConnector}; + +#[cfg(all(feature = "tokio-native-tls-comp", not(feature = "tokio-rustls-comp")))] use tokio_native_tls::TlsStream; #[cfg(unix)] @@ -28,7 +35,7 @@ 
pub(crate) enum Tokio { /// Represents a Tokio TCP connection. Tcp(TcpStreamTokio), /// Represents a Tokio TLS encrypted TCP connection - #[cfg(feature = "tokio-native-tls-comp")] + #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))] TcpTls(Box>), /// Represents a Tokio Unix connection. #[cfg(unix)] @@ -43,7 +50,7 @@ impl AsyncWrite for Tokio { ) -> Poll> { match &mut *self { Tokio::Tcp(r) => Pin::new(r).poll_write(cx, buf), - #[cfg(feature = "tokio-native-tls-comp")] + #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))] Tokio::TcpTls(r) => Pin::new(r).poll_write(cx, buf), #[cfg(unix)] Tokio::Unix(r) => Pin::new(r).poll_write(cx, buf), @@ -53,7 +60,7 @@ impl AsyncWrite for Tokio { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll> { match &mut *self { Tokio::Tcp(r) => Pin::new(r).poll_flush(cx), - #[cfg(feature = "tokio-native-tls-comp")] + #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))] Tokio::TcpTls(r) => Pin::new(r).poll_flush(cx), #[cfg(unix)] Tokio::Unix(r) => Pin::new(r).poll_flush(cx), @@ -63,7 +70,7 @@ impl AsyncWrite for Tokio { fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll> { match &mut *self { Tokio::Tcp(r) => Pin::new(r).poll_shutdown(cx), - #[cfg(feature = "tokio-native-tls-comp")] + #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))] Tokio::TcpTls(r) => Pin::new(r).poll_shutdown(cx), #[cfg(unix)] Tokio::Unix(r) => Pin::new(r).poll_shutdown(cx), @@ -79,7 +86,7 @@ impl AsyncRead for Tokio { ) -> Poll> { match &mut *self { Tokio::Tcp(r) => Pin::new(r).poll_read(cx, buf), - #[cfg(feature = "tokio-native-tls-comp")] + #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))] Tokio::TcpTls(r) => Pin::new(r).poll_read(cx, buf), #[cfg(unix)] Tokio::Unix(r) => Pin::new(r).poll_read(cx, buf), @@ -95,7 +102,7 @@ impl RedisRuntime for Tokio { .map(Tokio::Tcp)?) 
} - #[cfg(feature = "tls")] + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] async fn connect_tcp_tls( hostname: &str, socket_addr: SocketAddr, @@ -117,6 +124,24 @@ impl RedisRuntime for Tokio { .map(|con| Tokio::TcpTls(Box::new(con)))?) } + #[cfg(feature = "tls-rustls")] + async fn connect_tcp_tls( + hostname: &str, + socket_addr: SocketAddr, + insecure: bool, + ) -> RedisResult { + let config = create_rustls_config(insecure)?; + let tls_connector = TlsConnector::from(Arc::new(config)); + + Ok(tls_connector + .connect( + hostname.try_into()?, + TcpStreamTokio::connect(&socket_addr).await?, + ) + .await + .map(|con| Tokio::TcpTls(Box::new(con)))?) + } + #[cfg(unix)] async fn connect_unix(path: &Path) -> RedisResult { Ok(UnixStreamTokio::connect(path).await.map(Tokio::Unix)?) @@ -135,7 +160,7 @@ impl RedisRuntime for Tokio { fn boxed(self) -> Pin> { match self { Tokio::Tcp(x) => Box::pin(x), - #[cfg(feature = "tokio-native-tls-comp")] + #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))] Tokio::TcpTls(x) => Box::pin(x), #[cfg(unix)] Tokio::Unix(x) => Box::pin(x), diff --git a/redis/src/cluster.rs b/redis/src/cluster.rs index aa9520548..f7c596763 100644 --- a/redis/src/cluster.rs +++ b/redis/src/cluster.rs @@ -1,9 +1,6 @@ -//! Redis cluster support. +//! This module extends the library to support Redis Cluster. //! -//! This module extends the library to be able to use cluster. -//! ClusterClient implements traits of ConnectionLike and Commands. -//! -//! Note that the cluster support currently does not provide pubsub +//! Note that this module does not currently provide pubsub //! functionality. //! //! # Example @@ -39,35 +36,94 @@ //! .query(&mut connection).unwrap(); //! 
``` use std::cell::RefCell; -use std::collections::BTreeMap; use std::iter::Iterator; +use std::str::FromStr; use std::thread; use std::time::Duration; -use rand::{ - seq::{IteratorRandom, SliceRandom}, - thread_rng, Rng, -}; +use rand::{seq::IteratorRandom, thread_rng, Rng}; -use super::{ - cmd, parse_redis_value, - types::{HashMap, HashSet}, - Cmd, Connection, ConnectionAddr, ConnectionInfo, ConnectionLike, ErrorKind, IntoConnectionInfo, - RedisError, RedisResult, Value, +use crate::cluster_pipeline::UNROUTABLE_ERROR; +use crate::cluster_routing::{Routable, RoutingInfo, Slot, SLOT_SIZE}; +use crate::cmd::{cmd, Cmd}; +use crate::connection::{ + connect, Connection, ConnectionAddr, ConnectionInfo, ConnectionLike, RedisConnectionInfo, +}; +use crate::parser::parse_redis_value; +use crate::types::{ErrorKind, HashMap, HashSet, RedisError, RedisResult, Value}; +use crate::IntoConnectionInfo; +use crate::{ + cluster_client::ClusterParams, + cluster_routing::{Route, SlotAddr, SlotAddrs, SlotMap}, }; -use crate::cluster_client::ClusterParams; pub use crate::cluster_client::{ClusterClient, ClusterClientBuilder}; -use crate::cluster_pipeline::UNROUTABLE_ERROR; pub use crate::cluster_pipeline::{cluster_pipe, ClusterPipeline}; -use crate::cluster_routing::{Routable, RoutingInfo, Slot, SLOT_SIZE}; -type SlotMap = BTreeMap; +/// Implements the process of connecting to a Redis server +/// and obtaining and configuring a connection handle. +pub trait Connect: Sized { + /// Connect to a node, returning handle for command execution. + fn connect(info: T, timeout: Option) -> RedisResult + where + T: IntoConnectionInfo; + + /// Sends an already encoded (packed) command into the TCP socket and + /// does not read a response. This is useful for commands like + /// `MONITOR` which yield multiple items. This needs to be used with + /// care because it changes the state of the connection. 
+ fn send_packed_command(&mut self, cmd: &[u8]) -> RedisResult<()>; + + /// Sets the write timeout for the connection. + /// + /// If the provided value is `None`, then `send_packed_command` call will + /// block indefinitely. It is an error to pass the zero `Duration` to this + /// method. + fn set_write_timeout(&self, dur: Option) -> RedisResult<()>; + + /// Sets the read timeout for the connection. + /// + /// If the provided value is `None`, then `recv_response` call will + /// block indefinitely. It is an error to pass the zero `Duration` to this + /// method. + fn set_read_timeout(&self, dur: Option) -> RedisResult<()>; + + /// Fetches a single response from the connection. This is useful + /// if used in combination with `send_packed_command`. + fn recv_response(&mut self) -> RedisResult; +} + +impl Connect for Connection { + fn connect(info: T, timeout: Option) -> RedisResult + where + T: IntoConnectionInfo, + { + connect(&info.into_connection_info()?, timeout) + } + + fn send_packed_command(&mut self, cmd: &[u8]) -> RedisResult<()> { + Self::send_packed_command(self, cmd) + } + + fn set_write_timeout(&self, dur: Option) -> RedisResult<()> { + Self::set_write_timeout(self, dur) + } + + fn set_read_timeout(&self, dur: Option) -> RedisResult<()> { + Self::set_read_timeout(self, dur) + } -/// This is a connection of Redis cluster. -pub struct ClusterConnection { + fn recv_response(&mut self) -> RedisResult { + Self::recv_response(self) + } +} + +/// This represents a Redis Cluster connection. It stores the +/// underlying connections maintained for each node in the cluster, as well +/// as common parameters for connecting to nodes and executing commands. 
+pub struct ClusterConnection { initial_nodes: Vec, - connections: RefCell>, + connections: RefCell>, slots: RefCell, auto_reconnect: RefCell, read_from_replicas: bool, @@ -76,38 +132,19 @@ pub struct ClusterConnection { read_timeout: RefCell>, write_timeout: RefCell>, tls: Option, + retries: u32, } -#[derive(Clone, Copy)] -enum TlsMode { - Secure, - Insecure, -} - -impl TlsMode { - fn from_insecure_flag(insecure: bool) -> TlsMode { - if insecure { - TlsMode::Insecure - } else { - TlsMode::Secure - } - } -} - -impl ClusterConnection { +impl ClusterConnection +where + C: ConnectionLike + Connect, +{ pub(crate) fn new( cluster_params: ClusterParams, initial_nodes: Vec, - ) -> RedisResult { - let connections = Self::create_initial_connections( - &initial_nodes, - cluster_params.read_from_replicas, - cluster_params.username.clone(), - cluster_params.password.clone(), - )?; - - let connection = ClusterConnection { - connections: RefCell::new(connections), + ) -> RedisResult { + let connection = Self { + connections: RefCell::new(HashMap::new()), slots: RefCell::new(SlotMap::new()), auto_reconnect: RefCell::new(true), read_from_replicas: cluster_params.read_from_replicas, @@ -115,28 +152,11 @@ impl ClusterConnection { password: cluster_params.password, read_timeout: RefCell::new(None), write_timeout: RefCell::new(None), - #[cfg(feature = "tls")] - tls: { - if initial_nodes.is_empty() { - None - } else { - // TODO: Maybe should run through whole list and make sure they're all matching? 
- match &initial_nodes.get(0).unwrap().addr { - ConnectionAddr::Tcp(_, _) => None, - ConnectionAddr::TcpTls { - host: _, - port: _, - insecure, - } => Some(TlsMode::from_insecure_flag(*insecure)), - _ => None, - } - } - }, - #[cfg(not(feature = "tls"))] - tls: None, + tls: cluster_params.tls, initial_nodes: initial_nodes.to_vec(), + retries: cluster_params.retries, }; - connection.refresh_slots()?; + connection.create_initial_connections()?; Ok(connection) } @@ -195,14 +215,9 @@ impl ClusterConnection { } /// Check that all connections it has are available (`PING` internally). + #[doc(hidden)] pub fn check_connection(&mut self) -> bool { - let mut connections = self.connections.borrow_mut(); - for conn in connections.values_mut() { - if !conn.check_connection() { - return false; - } - } - true + ::check_connection(self) } pub(crate) fn execute_pipeline(&mut self, pipe: &ClusterPipeline) -> RedisResult> { @@ -216,34 +231,13 @@ impl ClusterConnection { /// connection, otherwise a Redis protocol error). When using unix /// sockets the connection is open until writing a command failed with a /// `BrokenPipe` error. 
- fn create_initial_connections( - initial_nodes: &[ConnectionInfo], - read_from_replicas: bool, - username: Option, - password: Option, - ) -> RedisResult> { - let mut connections = HashMap::with_capacity(initial_nodes.len()); - - for info in initial_nodes.iter() { - let addr = match info.addr { - ConnectionAddr::Tcp(ref host, port) => format!("redis://{}:{}", host, port), - ConnectionAddr::TcpTls { - ref host, - port, - insecure, - } => { - let tls_mode = TlsMode::from_insecure_flag(insecure); - build_connection_string(host, Some(port), Some(tls_mode)) - } - _ => panic!("No reach."), - }; + fn create_initial_connections(&self) -> RedisResult<()> { + let mut connections = HashMap::with_capacity(self.initial_nodes.len()); + + for info in self.initial_nodes.iter() { + let addr = info.addr.to_string(); - if let Ok(mut conn) = connect( - info.clone(), - read_from_replicas, - username.clone(), - password.clone(), - ) { + if let Ok(mut conn) = self.connect(&addr) { if conn.check_connection() { connections.insert(addr, conn); break; @@ -257,25 +251,16 @@ impl ClusterConnection { "It failed to check startup nodes.", ))); } - Ok(connections) + + *self.connections.borrow_mut() = connections; + self.refresh_slots()?; + Ok(()) } // Query a node to discover slot-> master mappings. 
fn refresh_slots(&self) -> RedisResult<()> { let mut slots = self.slots.borrow_mut(); - *slots = self.create_new_slots(|slot_data| { - let replica = if !self.read_from_replicas || slot_data.replicas().is_empty() { - slot_data.master().to_string() - } else { - slot_data - .replicas() - .choose(&mut thread_rng()) - .unwrap() - .to_string() - }; - - [slot_data.master().to_string(), replica] - })?; + *slots = self.create_new_slots()?; let mut nodes = slots.values().flatten().collect::>(); nodes.sort_unstable(); @@ -292,12 +277,7 @@ impl ClusterConnection { } } - if let Ok(mut conn) = connect( - addr.as_ref(), - self.read_from_replicas, - self.username.clone(), - self.password.clone(), - ) { + if let Ok(mut conn) = self.connect(addr) { if conn.check_connection() { conn.set_read_timeout(*self.read_timeout.borrow()).unwrap(); conn.set_write_timeout(*self.write_timeout.borrow()) @@ -313,10 +293,7 @@ impl ClusterConnection { Ok(()) } - fn create_new_slots(&self, mut get_addr: F) -> RedisResult - where - F: FnMut(&Slot) -> [String; 2], - { + fn create_new_slots(&self) -> RedisResult { let mut connections = self.connections.borrow_mut(); let mut new_slots = None; let mut rng = thread_rng(); @@ -324,7 +301,8 @@ impl ClusterConnection { let mut samples = connections.values_mut().choose_multiple(&mut rng, len); for conn in samples.iter_mut() { - if let Ok(mut slots_data) = get_slots(conn, self.tls) { + let value = conn.req_command(&slot_cmd())?; + if let Ok(mut slots_data) = parse_slots(value, self.tls) { slots_data.sort_by_key(|s| s.start()); let last_slot = slots_data.iter().try_fold(0, |prev_end, slot_data| { if prev_end != slot_data.start() { @@ -346,14 +324,19 @@ impl ClusterConnection { return Err(RedisError::from(( ErrorKind::ResponseError, "Slot refresh error.", - format!("Lacks the slots >= {}", last_slot), + format!("Lacks the slots >= {last_slot}"), ))); } new_slots = Some( slots_data .iter() - .map(|slot_data| (slot_data.end(), get_addr(slot_data))) + .map(|slot| { + 
( + slot.end(), + SlotAddrs::from_slot(slot, self.read_from_replicas), + ) + }) .collect(), ); break; @@ -370,17 +353,34 @@ impl ClusterConnection { } } + fn connect(&self, node: &str) -> RedisResult { + let params = ClusterParams { + password: self.password.clone(), + username: self.username.clone(), + tls: self.tls, + ..Default::default() + }; + let info = get_connection_info(node, params)?; + + let mut conn = C::connect(info, None)?; + if self.read_from_replicas { + // If READONLY is sent to primary nodes, it will have no effect + cmd("READONLY").query(&mut conn)?; + } + Ok(conn) + } + fn get_connection<'a>( &self, - connections: &'a mut HashMap, - route: (u16, usize), - ) -> RedisResult<(String, &'a mut Connection)> { - let (slot, idx) = route; + connections: &'a mut HashMap, + route: &Route, + ) -> RedisResult<(String, &'a mut C)> { let slots = self.slots.borrow(); - if let Some((_, addr)) = slots.range(&slot..).next() { + if let Some((_, slot_addrs)) = slots.range(route.slot()..).next() { + let addr = &slot_addrs.slot_addr(route.slot_addr()); Ok(( - addr[idx].clone(), - self.get_connection_by_addr(connections, &addr[idx])?, + addr.to_string(), + self.get_connection_by_addr(connections, addr)?, )) } else { // try a random node next. This is safe if slots are involved @@ -391,28 +391,67 @@ impl ClusterConnection { fn get_connection_by_addr<'a>( &self, - connections: &'a mut HashMap, + connections: &'a mut HashMap, addr: &str, - ) -> RedisResult<&'a mut Connection> { + ) -> RedisResult<&'a mut C> { if connections.contains_key(addr) { Ok(connections.get_mut(addr).unwrap()) } else { // Create new connection. 
// TODO: error handling - let conn = connect( - addr, - self.read_from_replicas, - self.username.clone(), - self.password.clone(), - )?; + let conn = self.connect(addr)?; Ok(connections.entry(addr.to_string()).or_insert(conn)) } } + fn get_addr_for_cmd(&self, cmd: &Cmd) -> RedisResult { + let slots = self.slots.borrow(); + + let addr_for_slot = |slot: u16, slot_addr: SlotAddr| -> RedisResult { + let (_, slot_addrs) = slots + .range(&slot..) + .next() + .ok_or((ErrorKind::ClusterDown, "Missing slot coverage"))?; + Ok(slot_addrs.slot_addr(&slot_addr).to_string()) + }; + + match RoutingInfo::for_routable(cmd) { + Some(RoutingInfo::Random) => { + let mut rng = thread_rng(); + Ok(addr_for_slot( + rng.gen_range(0..SLOT_SIZE), + SlotAddr::Master, + )?) + } + Some(RoutingInfo::MasterSlot(slot)) => Ok(addr_for_slot(slot, SlotAddr::Master)?), + Some(RoutingInfo::ReplicaSlot(slot)) => Ok(addr_for_slot(slot, SlotAddr::Replica)?), + _ => fail!(UNROUTABLE_ERROR), + } + } + + fn map_cmds_to_nodes(&self, cmds: &[Cmd]) -> RedisResult> { + let mut cmd_map: HashMap = HashMap::new(); + + for (idx, cmd) in cmds.iter().enumerate() { + let addr = self.get_addr_for_cmd(cmd)?; + let nc = cmd_map + .entry(addr.clone()) + .or_insert_with(|| NodeCmd::new(addr)); + nc.indexes.push(idx); + cmd.write_packed_command(&mut nc.pipe); + } + + let mut result = Vec::new(); + for (_, v) in cmd_map.drain() { + result.push(v); + } + Ok(result) + } + fn execute_on_all_nodes(&self, mut func: F) -> RedisResult where T: MergeResults, - F: FnMut(&mut Connection) -> RedisResult, + F: FnMut(&mut C) -> RedisResult, { let mut connections = self.connections.borrow_mut(); let mut results = HashMap::new(); @@ -430,19 +469,19 @@ impl ClusterConnection { where R: ?Sized + Routable, T: MergeResults + std::fmt::Debug, - F: FnMut(&mut Connection) -> RedisResult, + F: FnMut(&mut C) -> RedisResult, { let route = match RoutingInfo::for_routable(cmd) { Some(RoutingInfo::Random) => None, - Some(RoutingInfo::MasterSlot(slot)) 
=> Some((slot, 0)), - Some(RoutingInfo::ReplicaSlot(slot)) => Some((slot, 1)), + Some(RoutingInfo::MasterSlot(slot)) => Some(Route::new(slot, SlotAddr::Master)), + Some(RoutingInfo::ReplicaSlot(slot)) => Some(Route::new(slot, SlotAddr::Replica)), Some(RoutingInfo::AllNodes) | Some(RoutingInfo::AllMasters) => { return self.execute_on_all_nodes(func); } None => fail!(UNROUTABLE_ERROR), }; - let mut retries = 16; + let mut retries = self.retries; let mut excludes = HashSet::new(); let mut redirected = None::; let mut is_asking = false; @@ -451,7 +490,7 @@ impl ClusterConnection { let (addr, rv) = { let mut connections = self.connections.borrow_mut(); let (addr, conn) = if let Some(addr) = redirected.take() { - let conn = self.get_connection_by_addr(&mut *connections, &addr)?; + let conn = self.get_connection_by_addr(&mut connections, &addr)?; if is_asking { // if we are in asking mode we want to feed a single // ASKING command into the connection before what we @@ -461,9 +500,9 @@ impl ClusterConnection { } (addr.to_string(), conn) } else if !excludes.is_empty() || route.is_none() { - get_random_connection(&mut *connections, Some(&excludes)) + get_random_connection(&mut connections, Some(&excludes)) } else { - self.get_connection(&mut *connections, route.unwrap())? + self.get_connection(&mut connections, route.as_ref().unwrap())? }; (addr, func(conn)) }; @@ -471,18 +510,16 @@ impl ClusterConnection { match rv { Ok(rv) => return Ok(rv), Err(err) => { - retries -= 1; if retries == 0 { return Err(err); } + retries -= 1; if err.is_cluster_error() { let kind = err.kind(); if kind == ErrorKind::Ask { - redirected = err - .redirect_node() - .map(|(node, _slot)| build_connection_string(node, None, self.tls)); + redirected = err.redirect_node().map(|(node, _slot)| node.to_string()); is_asking = true; } else if kind == ErrorKind::Moved { // Refresh slots. @@ -490,9 +527,7 @@ impl ClusterConnection { excludes.clear(); // Request again. 
- redirected = err - .redirect_node() - .map(|(node, _slot)| build_connection_string(node, None, self.tls)); + redirected = err.redirect_node().map(|(node, _slot)| node.to_string()); is_asking = false; continue; } else if kind == ErrorKind::TryAgain || kind == ErrorKind::ClusterDown { @@ -503,17 +538,7 @@ impl ClusterConnection { continue; } } else if *self.auto_reconnect.borrow() && err.is_io_error() { - let new_connections = Self::create_initial_connections( - &self.initial_nodes, - self.read_from_replicas, - self.username.clone(), - self.password.clone(), - )?; - { - let mut connections = self.connections.borrow_mut(); - *connections = new_connections; - } - self.refresh_slots()?; + self.create_initial_connections()?; excludes.clear(); continue; } else { @@ -570,47 +595,6 @@ impl ClusterConnection { Ok(node_cmds) } - fn get_addr_for_cmd(&self, cmd: &Cmd) -> RedisResult { - let slots = self.slots.borrow(); - - let addr_for_slot = |slot: u16, idx: usize| -> RedisResult { - let (_, addr) = slots - .range(&slot..) - .next() - .ok_or((ErrorKind::ClusterDown, "Missing slot coverage"))?; - Ok(addr[idx].clone()) - }; - - match RoutingInfo::for_routable(cmd) { - Some(RoutingInfo::Random) => { - let mut rng = thread_rng(); - Ok(addr_for_slot(rng.gen_range(0..SLOT_SIZE) as u16, 0)?) 
- } - Some(RoutingInfo::MasterSlot(slot)) => Ok(addr_for_slot(slot, 0)?), - Some(RoutingInfo::ReplicaSlot(slot)) => Ok(addr_for_slot(slot, 1)?), - _ => fail!(UNROUTABLE_ERROR), - } - } - - fn map_cmds_to_nodes(&self, cmds: &[Cmd]) -> RedisResult> { - let mut cmd_map: HashMap = HashMap::new(); - - for (idx, cmd) in cmds.iter().enumerate() { - let addr = self.get_addr_for_cmd(cmd)?; - let nc = cmd_map - .entry(addr.clone()) - .or_insert_with(|| NodeCmd::new(addr)); - nc.indexes.push(idx); - cmd.write_packed_command(&mut nc.pipe); - } - - let mut result = Vec::new(); - for (_, v) in cmd_map.drain() { - result.push(v); - } - Ok(result) - } - // Receive from each node, keeping track of which commands need to be retried. fn recv_all_commands( &self, @@ -640,50 +624,7 @@ impl ClusterConnection { } } -trait MergeResults { - fn merge_results(_values: HashMap<&str, Self>) -> Self - where - Self: Sized; -} - -impl MergeResults for Value { - fn merge_results(values: HashMap<&str, Value>) -> Value { - let mut items = vec![]; - for (addr, value) in values.into_iter() { - items.push(Value::Bulk(vec![ - Value::Data(addr.as_bytes().to_vec()), - value, - ])); - } - Value::Bulk(items) - } -} - -impl MergeResults for Vec { - fn merge_results(_values: HashMap<&str, Vec>) -> Vec { - unreachable!("attempted to merge a pipeline. 
This should not happen"); - } -} - -#[derive(Debug)] -struct NodeCmd { - // The original command indexes - indexes: Vec, - pipe: Vec, - addr: String, -} - -impl NodeCmd { - fn new(a: String) -> NodeCmd { - NodeCmd { - indexes: vec![], - pipe: vec![], - addr: a, - } - } -} - -impl ConnectionLike for ClusterConnection { +impl ConnectionLike for ClusterConnection { fn supports_pipelining(&self) -> bool { false } @@ -734,32 +675,63 @@ impl ConnectionLike for ClusterConnection { } } -fn connect( - info: T, - read_from_replicas: bool, - username: Option, - password: Option, -) -> RedisResult -where - T: std::fmt::Debug, -{ - let mut connection_info = info.into_connection_info()?; - connection_info.redis.username = username; - connection_info.redis.password = password; - let client = super::Client::open(connection_info)?; +trait MergeResults { + fn merge_results(_values: HashMap<&str, Self>) -> Self + where + Self: Sized; +} + +impl MergeResults for Value { + fn merge_results(values: HashMap<&str, Value>) -> Value { + let mut items = vec![]; + for (addr, value) in values.into_iter() { + items.push(Value::Bulk(vec![ + Value::Data(addr.as_bytes().to_vec()), + value, + ])); + } + Value::Bulk(items) + } +} + +impl MergeResults for Vec { + fn merge_results(_values: HashMap<&str, Vec>) -> Vec { + unreachable!("attempted to merge a pipeline. This should not happen"); + } +} + +#[derive(Debug)] +struct NodeCmd { + // The original command indexes + indexes: Vec, + pipe: Vec, + addr: String, +} - let mut con = client.get_connection()?; - if read_from_replicas { - // If READONLY is sent to primary nodes, it will have no effect - cmd("READONLY").query(&mut con)?; +impl NodeCmd { + fn new(a: String) -> NodeCmd { + NodeCmd { + indexes: vec![], + pipe: vec![], + addr: a, + } } - Ok(con) } -fn get_random_connection<'a>( - connections: &'a mut HashMap, +/// TlsMode indicates use or do not use verification of certification. 
+/// Check [ConnectionAddr](ConnectionAddr::TcpTls::insecure) for more. +#[derive(Clone, Copy)] +pub enum TlsMode { + /// Secure verify certification. + Secure, + /// Insecure do not verify certification. + Insecure, +} + +fn get_random_connection<'a, C: ConnectionLike + Connect + Sized>( + connections: &'a mut HashMap, excludes: Option<&'a HashSet>, -) -> (String, &'a mut Connection) { +) -> (String, &'a mut C) { let mut rng = thread_rng(); let addr = match excludes { Some(excludes) if excludes.len() < connections.len() => connections @@ -775,16 +747,12 @@ fn get_random_connection<'a>( (addr, con) } -// Get slot data from connection. -fn get_slots(connection: &mut Connection, tls_mode: Option) -> RedisResult> { - let mut cmd = Cmd::new(); - cmd.arg("CLUSTER").arg("SLOTS"); - let value = connection.req_command(&cmd)?; - +// Parse slot data from raw redis value. +pub(crate) fn parse_slots(raw_slot_resp: Value, tls: Option) -> RedisResult> { // Parse response. let mut result = Vec::with_capacity(2); - if let Value::Bulk(items) = value { + if let Value::Bulk(items) = raw_slot_resp { let mut iter = items.into_iter(); while let Some(Value::Bulk(item)) = iter.next() { if item.len() < 3 { @@ -826,7 +794,7 @@ fn get_slots(connection: &mut Connection, tls_mode: Option) -> RedisRes } else { return None; }; - Some(build_connection_string(&ip, Some(port), tls_mode)) + Some(get_connection_addr(ip.into_owned(), port, tls).to_string()) } else { None } @@ -845,16 +813,96 @@ fn get_slots(connection: &mut Connection, tls_mode: Option) -> RedisRes Ok(result) } -fn build_connection_string(host: &str, port: Option, tls_mode: Option) -> String { - let host_port = match port { - Some(port) => format!("{}:{}", host, port), - None => host.to_string(), - }; - match tls_mode { - None => format!("redis://{}", host_port), - Some(TlsMode::Insecure) => { - format!("rediss://{}/#insecure", host_port) +// The node string passed to this function will always be in the format host:port as it is 
either: +// - Created by calling ConnectionAddr::to_string (unix connections are not supported in cluster mode) +// - Returned from redis via the ASK/MOVED response +pub(crate) fn get_connection_info( + node: &str, + cluster_params: ClusterParams, +) -> RedisResult { + let invalid_error = || (ErrorKind::InvalidClientConfig, "Invalid node string"); + + let (host, port) = node + .rsplit_once(':') + .and_then(|(host, port)| { + Some(host.trim_start_matches('[').trim_end_matches(']')) + .filter(|h| !h.is_empty()) + .zip(u16::from_str(port).ok()) + }) + .ok_or_else(invalid_error)?; + + Ok(ConnectionInfo { + addr: get_connection_addr(host.to_string(), port, cluster_params.tls), + redis: RedisConnectionInfo { + password: cluster_params.password, + username: cluster_params.username, + ..Default::default() + }, + }) +} + +fn get_connection_addr(host: String, port: u16, tls: Option) -> ConnectionAddr { + match tls { + Some(TlsMode::Secure) => ConnectionAddr::TcpTls { + host, + port, + insecure: false, + }, + Some(TlsMode::Insecure) => ConnectionAddr::TcpTls { + host, + port, + insecure: true, + }, + _ => ConnectionAddr::Tcp(host, port), + } +} + +pub(crate) fn slot_cmd() -> Cmd { + let mut cmd = Cmd::new(); + cmd.arg("CLUSTER").arg("SLOTS"); + cmd +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_cluster_node_host_port() { + let cases = vec![ + ( + "127.0.0.1:6379", + ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379u16), + ), + ( + "localhost.localdomain:6379", + ConnectionAddr::Tcp("localhost.localdomain".to_string(), 6379u16), + ), + ( + "dead::cafe:beef:30001", + ConnectionAddr::Tcp("dead::cafe:beef".to_string(), 30001u16), + ), + ( + "[fe80::cafe:beef%en1]:30001", + ConnectionAddr::Tcp("fe80::cafe:beef%en1".to_string(), 30001u16), + ), + ]; + + for (input, expected) in cases { + let res = get_connection_info(input, ClusterParams::default()); + assert_eq!(res.unwrap().addr, expected); + } + + let cases = vec![":0", "[]:6379"]; + for input in cases { 
+ let res = get_connection_info(input, ClusterParams::default()); + assert_eq!( + res.err(), + Some(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Invalid node string", + ))), + ); } - Some(TlsMode::Secure) => format!("rediss://{}", host_port), } } diff --git a/redis/src/cluster_async/LICENSE b/redis/src/cluster_async/LICENSE new file mode 100644 index 000000000..aaa71a163 --- /dev/null +++ b/redis/src/cluster_async/LICENSE @@ -0,0 +1,7 @@ +Copyright 2019 Atsushi Koge, Markus Westerlind + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/redis/src/cluster_async/mod.rs b/redis/src/cluster_async/mod.rs new file mode 100644 index 000000000..dec5cd905 --- /dev/null +++ b/redis/src/cluster_async/mod.rs @@ -0,0 +1,929 @@ +//! This module provides async functionality for Redis Cluster. +//! +//! By default, [`ClusterConnection`] makes use of [`MultiplexedConnection`] and maintains a pool +//! of connections to each node in the cluster. While it generally behaves similarly to +//! 
the sync cluster module, certain commands do not route identically, due most notably to +//! a current lack of support for routing commands to multiple nodes. +//! +//! Also note that pubsub functionality is not currently provided by this module. +//! +//! # Example +//! ```rust,no_run +//! use redis::cluster::ClusterClient; +//! use redis::AsyncCommands; +//! +//! async fn fetch_an_integer() -> String { +//! let nodes = vec!["redis://127.0.0.1/"]; +//! let client = ClusterClient::new(nodes).unwrap(); +//! let mut connection = client.get_async_connection().await.unwrap(); +//! let _: () = connection.set("test", "test_data").await.unwrap(); +//! let rv: String = connection.get("test").await.unwrap(); +//! return rv; +//! } +//! ``` +use std::{ + collections::{HashMap, HashSet}, + fmt, io, + iter::Iterator, + marker::Unpin, + mem, + pin::Pin, + sync::Arc, + task::{self, Poll}, + time::Duration, +}; + +use crate::{ + aio::{ConnectionLike, MultiplexedConnection}, + cluster::{get_connection_info, parse_slots, slot_cmd}, + cluster_client::ClusterParams, + cluster_routing::{Route, RoutingInfo, Slot, SlotAddr, SlotAddrs, SlotMap}, + Cmd, ConnectionInfo, ErrorKind, IntoConnectionInfo, RedisError, RedisFuture, RedisResult, + Value, +}; + +#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] +use crate::aio::{async_std::AsyncStd, RedisRuntime}; +use futures::{ + future::{self, BoxFuture}, + prelude::*, + ready, stream, +}; +use log::trace; +use pin_project_lite::pin_project; +use rand::seq::IteratorRandom; +use rand::thread_rng; +use tokio::sync::{mpsc, oneshot}; + +const SLOT_SIZE: usize = 16384; + +/// This represents an async Redis Cluster connection. It stores the +/// underlying connections maintained for each node in the cluster, as well +/// as common parameters for connecting to nodes and executing commands. 
+#[derive(Clone)] +pub struct ClusterConnection(mpsc::Sender>); + +impl ClusterConnection +where + C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static, +{ + pub(crate) async fn new( + initial_nodes: &[ConnectionInfo], + cluster_params: ClusterParams, + ) -> RedisResult> { + ClusterConnInner::new(initial_nodes, cluster_params) + .await + .map(|inner| { + let (tx, mut rx) = mpsc::channel::>(100); + let stream = async move { + let _ = stream::poll_fn(move |cx| rx.poll_recv(cx)) + .map(Ok) + .forward(inner) + .await; + }; + #[cfg(feature = "tokio-comp")] + tokio::spawn(stream); + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + AsyncStd::spawn(stream); + + ClusterConnection(tx) + }) + } +} + +type ConnectionFuture = future::Shared>; +type ConnectionMap = HashMap>; + +struct ClusterConnInner { + connections: ConnectionMap, + slots: SlotMap, + state: ConnectionState, + #[allow(clippy::complexity)] + in_flight_requests: stream::FuturesUnordered< + Pin)>, Response, C>>>, + >, + refresh_error: Option, + pending_requests: Vec>, + cluster_params: ClusterParams, +} + +#[derive(Clone)] +enum CmdArg { + Cmd { + cmd: Arc, + func: fn(C, Arc) -> RedisFuture<'static, Response>, + }, + Pipeline { + pipeline: Arc, + offset: usize, + count: usize, + func: fn(C, Arc, usize, usize) -> RedisFuture<'static, Response>, + }, +} + +impl CmdArg { + fn exec(&self, con: C) -> RedisFuture<'static, Response> { + match self { + Self::Cmd { cmd, func } => func(con, cmd.clone()), + Self::Pipeline { + pipeline, + offset, + count, + func, + } => func(con, pipeline.clone(), *offset, *count), + } + } + + fn route(&self) -> Option { + fn route_for_command(cmd: &Cmd) -> Option { + match RoutingInfo::for_routable(cmd) { + Some(RoutingInfo::Random) => None, + Some(RoutingInfo::MasterSlot(slot)) => Some(Route::new(slot, SlotAddr::Master)), + Some(RoutingInfo::ReplicaSlot(slot)) => Some(Route::new(slot, SlotAddr::Replica)), + Some(RoutingInfo::AllNodes) | 
Some(RoutingInfo::AllMasters) => None, + _ => None, + } + } + + match self { + Self::Cmd { ref cmd, .. } => route_for_command(cmd), + Self::Pipeline { ref pipeline, .. } => { + let mut iter = pipeline.cmd_iter(); + let slot = iter.next().map(route_for_command)?; + for cmd in iter { + if slot != route_for_command(cmd) { + return None; + } + } + slot + } + } + } +} + +enum Response { + Single(Value), + Multiple(Vec), +} + +struct Message { + cmd: CmdArg, + sender: oneshot::Sender>, +} + +type RecoverFuture = + BoxFuture<'static, Result<(SlotMap, ConnectionMap), (RedisError, ConnectionMap)>>; + +enum ConnectionState { + PollComplete, + Recover(RecoverFuture), +} + +impl fmt::Debug for ConnectionState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}", + match self { + ConnectionState::PollComplete => "PollComplete", + ConnectionState::Recover(_) => "Recover", + } + ) + } +} + +struct RequestInfo { + cmd: CmdArg, + route: Option, + excludes: HashSet, +} + +pin_project! { + #[project = RequestStateProj] + enum RequestState { + None, + Future { + #[pin] + future: F, + }, + Sleep { + #[pin] + sleep: BoxFuture<'static, ()>, + }, + } +} + +struct PendingRequest { + retry: u32, + sender: oneshot::Sender>, + info: RequestInfo, +} + +pin_project! 
{ + struct Request { + max_retries: u32, + request: Option>, + #[pin] + future: RequestState, + } +} + +#[must_use] +enum Next { + TryNewConnection { + request: PendingRequest, + error: Option, + }, + Err { + request: PendingRequest, + error: RedisError, + }, + Done, +} + +impl Future for Request +where + F: Future)>, + C: ConnectionLike, +{ + type Output = Next; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll { + let mut this = self.as_mut().project(); + if this.request.is_none() { + return Poll::Ready(Next::Done); + } + let future = match this.future.as_mut().project() { + RequestStateProj::Future { future } => future, + RequestStateProj::Sleep { sleep } => { + ready!(sleep.poll(cx)); + return Next::TryNewConnection { + request: self.project().request.take().unwrap(), + error: None, + } + .into(); + } + _ => panic!("Request future must be Some"), + }; + match ready!(future.poll(cx)) { + (_, Ok(item)) => { + trace!("Ok"); + self.respond(Ok(item)); + Next::Done.into() + } + (addr, Err(err)) => { + trace!("Request error {}", err); + + let request = this.request.as_mut().unwrap(); + + if request.retry >= *this.max_retries { + self.respond(Err(err)); + return Next::Done.into(); + } + request.retry = request.retry.saturating_add(1); + + if let Some(error_code) = err.code() { + if error_code == "MOVED" || error_code == "ASK" { + // Refresh slots and request again. + request.info.excludes.clear(); + return Next::Err { + request: this.request.take().unwrap(), + error: err, + } + .into(); + } else if error_code == "TRYAGAIN" || error_code == "CLUSTERDOWN" { + // Sleep and retry. 
+ let sleep_duration = + Duration::from_millis(2u64.pow(request.retry.clamp(7, 16)) * 10); + request.info.excludes.clear(); + this.future.set(RequestState::Sleep { + #[cfg(feature = "tokio-comp")] + sleep: Box::pin(tokio::time::sleep(sleep_duration)), + + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + sleep: Box::pin(async_std::task::sleep(sleep_duration)), + }); + return self.poll(cx); + } + } + + request.info.excludes.insert(addr); + + Next::TryNewConnection { + request: this.request.take().unwrap(), + error: Some(err), + } + .into() + } + } + } +} + +impl Request +where + F: Future)>, + C: ConnectionLike, +{ + fn respond(self: Pin<&mut Self>, msg: RedisResult) { + // If `send` errors the receiver has dropped and thus does not care about the message + let _ = self + .project() + .request + .take() + .expect("Result should only be sent once") + .sender + .send(msg); + } +} + +impl ClusterConnInner +where + C: ConnectionLike + Connect + Clone + Send + Sync + 'static, +{ + async fn new( + initial_nodes: &[ConnectionInfo], + cluster_params: ClusterParams, + ) -> RedisResult { + let connections = + Self::create_initial_connections(initial_nodes, cluster_params.clone()).await?; + let mut connection = ClusterConnInner { + connections, + slots: Default::default(), + in_flight_requests: Default::default(), + refresh_error: None, + pending_requests: Vec::new(), + state: ConnectionState::PollComplete, + cluster_params, + }; + let (slots, connections) = connection.refresh_slots().await.map_err(|(err, _)| err)?; + connection.slots = slots; + connection.connections = connections; + Ok(connection) + } + + async fn create_initial_connections( + initial_nodes: &[ConnectionInfo], + params: ClusterParams, + ) -> RedisResult> { + let connections = stream::iter(initial_nodes.iter().cloned()) + .map(|info| { + let params = params.clone(); + async move { + let addr = info.addr.to_string(); + let result = connect_and_check(&addr, params).await; + match result { + 
Ok(conn) => Some((addr, async { conn }.boxed().shared())), + Err(e) => { + trace!("Failed to connect to initial node: {:?}", e); + None + } + } + } + }) + .buffer_unordered(initial_nodes.len()) + .fold( + HashMap::with_capacity(initial_nodes.len()), + |mut connections: ConnectionMap, conn| async move { + connections.extend(conn); + connections + }, + ) + .await; + if connections.is_empty() { + return Err(RedisError::from(( + ErrorKind::IoError, + "Failed to create initial connections", + ))); + } + Ok(connections) + } + + // Query a node to discover slot-> master mappings. + fn refresh_slots( + &mut self, + ) -> impl Future), (RedisError, ConnectionMap)>> + { + let mut connections = mem::take(&mut self.connections); + let cluster_params = self.cluster_params.clone(); + + async move { + let mut result = Ok(SlotMap::new()); + for (_, conn) in connections.iter_mut() { + let mut conn = conn.clone().await; + let value = match conn.req_packed_command(&slot_cmd()).await { + Ok(value) => value, + Err(err) => { + result = Err(err); + continue; + } + }; + match parse_slots(value, cluster_params.tls) + .and_then(|v| Self::build_slot_map(v, cluster_params.read_from_replicas)) + { + Ok(s) => { + result = Ok(s); + break; + } + Err(err) => result = Err(err), + } + } + let slots = match result { + Ok(slots) => slots, + Err(err) => return Err((err, connections)), + }; + + let mut nodes = slots.values().flatten().collect::>(); + nodes.sort_unstable(); + nodes.dedup(); + + // Remove dead connections and connect to new nodes if necessary + let mut new_connections = HashMap::with_capacity(slots.len()); + + for addr in nodes { + if !new_connections.contains_key(addr) { + let new_connection = if let Some(conn) = connections.remove(addr) { + let mut conn = conn.await; + match check_connection(&mut conn).await { + Ok(_) => Some((addr.to_string(), conn)), + Err(_) => match connect_and_check(addr, cluster_params.clone()).await { + Ok(conn) => Some((addr.to_string(), conn)), + Err(_) => None, 
+ }, + } + } else { + match connect_and_check(addr, cluster_params.clone()).await { + Ok(conn) => Some((addr.to_string(), conn)), + Err(_) => None, + } + }; + if let Some((addr, new_connection)) = new_connection { + new_connections.insert(addr, async { new_connection }.boxed().shared()); + } + } + } + + Ok((slots, new_connections)) + } + } + + fn build_slot_map(mut slots_data: Vec, read_from_replicas: bool) -> RedisResult { + slots_data.sort_by_key(|slot_data| slot_data.start()); + let last_slot = slots_data.iter().try_fold(0, |prev_end, slot_data| { + if prev_end != slot_data.start() { + return Err(RedisError::from(( + ErrorKind::ResponseError, + "Slot refresh error.", + format!( + "Received overlapping slots {} and {}..{}", + prev_end, + slot_data.start(), + slot_data.end() + ), + ))); + } + Ok(slot_data.end() + 1) + })?; + + if usize::from(last_slot) != SLOT_SIZE { + return Err(RedisError::from(( + ErrorKind::ResponseError, + "Slot refresh error.", + format!("Lacks the slots >= {last_slot}"), + ))); + } + let slot_map = slots_data + .iter() + .map(|slot| (slot.end(), SlotAddrs::from_slot(slot, read_from_replicas))) + .collect(); + trace!("{:?}", slot_map); + Ok(slot_map) + } + + fn get_connection(&mut self, route: &Route) -> (String, ConnectionFuture) { + if let Some((_, node_addrs)) = self.slots.range(&route.slot()..).next() { + let addr = node_addrs.slot_addr(route.slot_addr()).to_string(); + if let Some(conn) = self.connections.get(&addr) { + return (addr, conn.clone()); + } + + // Create new connection. 
+ // + let (_, random_conn) = get_random_connection(&self.connections, None); // TODO Only do this lookup if the first check fails + let connection_future = { + let addr = addr.clone(); + let params = self.cluster_params.clone(); + async move { + match connect_and_check(&addr, params).await { + Ok(conn) => conn, + Err(_) => random_conn.await, + } + } + } + .boxed() + .shared(); + self.connections + .insert(addr.clone(), connection_future.clone()); + (addr, connection_future) + } else { + // Return a random connection + get_random_connection(&self.connections, None) + } + } + + fn try_request( + &mut self, + info: &RequestInfo, + ) -> impl Future)> { + // TODO remove clone by changing the ConnectionLike trait + let cmd = info.cmd.clone(); + let (addr, conn) = if !info.excludes.is_empty() || info.route.is_none() { + get_random_connection(&self.connections, Some(&info.excludes)) + } else { + self.get_connection(info.route.as_ref().unwrap()) + }; + async move { + let conn = conn.await; + let result = cmd.exec(conn).await; + (addr, result) + } + } + + fn poll_recover( + &mut self, + cx: &mut task::Context<'_>, + mut future: RecoverFuture, + ) -> Poll> { + match future.as_mut().poll(cx) { + Poll::Ready(Ok((slots, connections))) => { + trace!("Recovered with {} connections!", connections.len()); + self.slots = slots; + self.connections = connections; + self.state = ConnectionState::PollComplete; + Poll::Ready(Ok(())) + } + Poll::Pending => { + self.state = ConnectionState::Recover(future); + trace!("Recover not ready"); + Poll::Pending + } + Poll::Ready(Err((err, connections))) => { + self.connections = connections; + self.state = ConnectionState::Recover(Box::pin(self.refresh_slots())); + Poll::Ready(Err(err)) + } + } + } + + fn poll_complete(&mut self, cx: &mut task::Context<'_>) -> Poll> { + let mut connection_error = None; + + if !self.pending_requests.is_empty() { + let mut pending_requests = mem::take(&mut self.pending_requests); + for request in 
pending_requests.drain(..) { + // Drop the request if noone is waiting for a response to free up resources for + // requests callers care about (load shedding). It will be ambigous whether the + // request actually goes through regardless. + if request.sender.is_closed() { + continue; + } + + let future = self.try_request(&request.info); + self.in_flight_requests.push(Box::pin(Request { + max_retries: self.cluster_params.retries, + request: Some(request), + future: RequestState::Future { + future: future.boxed(), + }, + })); + } + self.pending_requests = pending_requests; + } + + loop { + let result = match Pin::new(&mut self.in_flight_requests).poll_next(cx) { + Poll::Ready(Some(result)) => result, + Poll::Ready(None) | Poll::Pending => break, + }; + let self_ = &mut *self; + match result { + Next::Done => {} + Next::TryNewConnection { request, error } => { + if let Some(error) = error { + if request.info.excludes.len() >= self_.connections.len() { + let _ = request.sender.send(Err(error)); + continue; + } + } + let future = self.try_request(&request.info); + self.in_flight_requests.push(Box::pin(Request { + max_retries: self.cluster_params.retries, + request: Some(request), + future: RequestState::Future { + future: Box::pin(future), + }, + })); + } + Next::Err { request, error } => { + connection_error = Some(error); + self.pending_requests.push(request); + } + } + } + + if let Some(err) = connection_error { + Poll::Ready(Err(err)) + } else if self.in_flight_requests.is_empty() { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } + + fn send_refresh_error(&mut self) { + if self.refresh_error.is_some() { + if let Some(mut request) = Pin::new(&mut self.in_flight_requests) + .iter_pin_mut() + .find(|request| request.request.is_some()) + { + (*request) + .as_mut() + .respond(Err(self.refresh_error.take().unwrap())); + } else if let Some(request) = self.pending_requests.pop() { + let _ = request.sender.send(Err(self.refresh_error.take().unwrap())); + } + } + } 
+} + +impl Sink> for ClusterConnInner +where + C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static, +{ + type Error = (); + + fn poll_ready( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + ) -> Poll> { + match mem::replace(&mut self.state, ConnectionState::PollComplete) { + ConnectionState::PollComplete => Poll::Ready(Ok(())), + ConnectionState::Recover(future) => { + match ready!(self.as_mut().poll_recover(cx, future)) { + Ok(()) => Poll::Ready(Ok(())), + Err(err) => { + // We failed to reconnect, while we will try again we will report the + // error if we can to avoid getting trapped in an infinite loop of + // trying to reconnect + if let Some(mut request) = Pin::new(&mut self.in_flight_requests) + .iter_pin_mut() + .find(|request| request.request.is_some()) + { + (*request).as_mut().respond(Err(err)); + } else { + self.refresh_error = Some(err); + } + Poll::Ready(Ok(())) + } + } + } + } + } + + fn start_send(mut self: Pin<&mut Self>, msg: Message) -> Result<(), Self::Error> { + trace!("start_send"); + let Message { cmd, sender } = msg; + + let excludes = HashSet::new(); + let slot = cmd.route(); + + let info = RequestInfo { + cmd, + route: slot, + excludes, + }; + + self.pending_requests.push(PendingRequest { + retry: 0, + sender, + info, + }); + Ok(()) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + ) -> Poll> { + trace!("poll_complete: {:?}", self.state); + loop { + self.send_refresh_error(); + + match mem::replace(&mut self.state, ConnectionState::PollComplete) { + ConnectionState::Recover(future) => { + match ready!(self.as_mut().poll_recover(cx, future)) { + Ok(()) => (), + Err(err) => { + // We failed to reconnect, while we will try again we will report the + // error if we can to avoid getting trapped in an infinite loop of + // trying to reconnect + self.refresh_error = Some(err); + + // Give other tasks a chance to progress before we try to recover + // again. 
Since the future may not have registered a wake up we do so + // now so the task is not forgotten + cx.waker().wake_by_ref(); + return Poll::Pending; + } + } + } + ConnectionState::PollComplete => match ready!(self.poll_complete(cx)) { + Ok(()) => return Poll::Ready(Ok(())), + Err(err) => { + trace!("Recovering {}", err); + self.state = ConnectionState::Recover(Box::pin(self.refresh_slots())); + } + }, + } + } + } + + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + ) -> Poll> { + // Try to drive any in flight requests to completion + match self.poll_complete(cx) { + Poll::Ready(result) => { + result.map_err(|_| ())?; + } + Poll::Pending => (), + }; + // If we no longer have any requests in flight we are done (skips any reconnection + // attempts) + if self.in_flight_requests.is_empty() { + return Poll::Ready(Ok(())); + } + + self.poll_flush(cx) + } +} + +impl ConnectionLike for ClusterConnection +where + C: ConnectionLike + Send + 'static, +{ + fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> { + trace!("req_packed_command"); + let (sender, receiver) = oneshot::channel(); + Box::pin(async move { + self.0 + .send(Message { + cmd: CmdArg::Cmd { + cmd: Arc::new(cmd.clone()), // TODO Remove this clone? 
+ func: |mut conn, cmd| { + Box::pin(async move { + conn.req_packed_command(&cmd).await.map(Response::Single) + }) + }, + }, + sender, + }) + .await + .map_err(|_| { + RedisError::from(io::Error::new( + io::ErrorKind::BrokenPipe, + "redis_cluster: Unable to send command", + )) + })?; + receiver + .await + .unwrap_or_else(|_| { + Err(RedisError::from(io::Error::new( + io::ErrorKind::BrokenPipe, + "redis_cluster: Unable to receive command", + ))) + }) + .map(|response| match response { + Response::Single(value) => value, + Response::Multiple(_) => unreachable!(), + }) + }) + } + + fn req_packed_commands<'a>( + &'a mut self, + pipeline: &'a crate::Pipeline, + offset: usize, + count: usize, + ) -> RedisFuture<'a, Vec> { + let (sender, receiver) = oneshot::channel(); + Box::pin(async move { + self.0 + .send(Message { + cmd: CmdArg::Pipeline { + pipeline: Arc::new(pipeline.clone()), // TODO Remove this clone? + offset, + count, + func: |mut conn, pipeline, offset, count| { + Box::pin(async move { + conn.req_packed_commands(&pipeline, offset, count) + .await + .map(Response::Multiple) + }) + }, + }, + sender, + }) + .await + .map_err(|_| RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe)))?; + + receiver + .await + .unwrap_or_else(|_| { + Err(RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe))) + }) + .map(|response| match response { + Response::Multiple(values) => values, + Response::Single(_) => unreachable!(), + }) + }) + } + + fn get_db(&self) -> i64 { + 0 + } +} +/// Implements the process of connecting to a Redis server +/// and obtaining a connection handle. +pub trait Connect: Sized { + /// Connect to a node, returning handle for command execution. 
+ fn connect<'a, T>(info: T) -> RedisFuture<'a, Self> + where + T: IntoConnectionInfo + Send + 'a; +} + +impl Connect for MultiplexedConnection { + fn connect<'a, T>(info: T) -> RedisFuture<'a, MultiplexedConnection> + where + T: IntoConnectionInfo + Send + 'a, + { + async move { + let connection_info = info.into_connection_info()?; + let client = crate::Client::open(connection_info)?; + + #[cfg(feature = "tokio-comp")] + return client.get_multiplexed_tokio_connection().await; + + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + return client.get_multiplexed_async_std_connection().await; + } + .boxed() + } +} + +async fn connect_and_check(node: &str, params: ClusterParams) -> RedisResult +where + C: ConnectionLike + Connect + Send + 'static, +{ + let read_from_replicas = params.read_from_replicas; + let info = get_connection_info(node, params)?; + let mut conn = C::connect(info).await?; + check_connection(&mut conn).await?; + if read_from_replicas { + // If READONLY is sent to primary nodes, it will have no effect + crate::cmd("READONLY").query_async(&mut conn).await?; + } + Ok(conn) +} + +async fn check_connection(conn: &mut C) -> RedisResult<()> +where + C: ConnectionLike + Send + 'static, +{ + let mut cmd = Cmd::new(); + cmd.arg("PING"); + cmd.query_async::<_, String>(conn).await?; + Ok(()) +} + +fn get_random_connection<'a, C>( + connections: &'a ConnectionMap, + excludes: Option<&'a HashSet>, +) -> (String, ConnectionFuture) +where + C: Clone, +{ + debug_assert!(!connections.is_empty()); + + let mut rng = thread_rng(); + let sample = match excludes { + Some(excludes) if excludes.len() < connections.len() => { + let target_keys = connections.keys().filter(|key| !excludes.contains(*key)); + target_keys.choose(&mut rng) + } + _ => connections.keys().choose(&mut rng), + }; + + let addr = sample.expect("No targets to choose from"); + (addr.to_string(), connections.get(addr).unwrap().clone()) +} diff --git a/redis/src/cluster_client.rs 
b/redis/src/cluster_client.rs index f5815c885..6f68c5b36 100644 --- a/redis/src/cluster_client.rs +++ b/redis/src/cluster_client.rs @@ -1,6 +1,23 @@ -use crate::cluster::ClusterConnection; use crate::connection::{ConnectionAddr, ConnectionInfo, IntoConnectionInfo}; use crate::types::{ErrorKind, RedisError, RedisResult}; +use crate::{cluster, cluster::TlsMode}; + +#[cfg(feature = "cluster-async")] +use crate::cluster_async; + +const DEFAULT_RETRIES: u32 = 16; + +/// Parameters specific to builder, so that +/// builder parameters may have different types +/// than final ClusterParams +#[derive(Default)] +struct BuilderParams { + password: Option, + username: Option, + read_from_replicas: bool, + tls: Option, + retries: Option, +} /// Redis cluster specific parameters. #[derive(Default, Clone)] @@ -8,16 +25,33 @@ pub(crate) struct ClusterParams { pub(crate) password: Option, pub(crate) username: Option, pub(crate) read_from_replicas: bool, + /// tls indicates tls behavior of connections. + /// When Some(TlsMode), connections use tls and verify certification depends on TlsMode. + /// When None, connections do not use tls. + pub(crate) tls: Option, + pub(crate) retries: u32, +} + +impl From for ClusterParams { + fn from(value: BuilderParams) -> Self { + Self { + password: value.password, + username: value.username, + read_from_replicas: value.read_from_replicas, + tls: value.tls, + retries: value.retries.unwrap_or(DEFAULT_RETRIES), + } + } } /// Used to configure and build a [`ClusterClient`]. pub struct ClusterClientBuilder { initial_nodes: RedisResult>, - cluster_params: ClusterParams, + builder_params: BuilderParams, } impl ClusterClientBuilder { - /// Creates a new `ClusterClientBuilder` with the the provided initial_nodes. + /// Creates a new `ClusterClientBuilder` with the provided initial_nodes. /// /// This is the same as `ClusterClient::builder(initial_nodes)`. 
pub fn new(initial_nodes: Vec) -> ClusterClientBuilder { @@ -26,11 +60,11 @@ impl ClusterClientBuilder { .into_iter() .map(|x| x.into_connection_info()) .collect(), - cluster_params: ClusterParams::default(), + builder_params: Default::default(), } } - /// Creates a new [`ClusterClient`] with the parameters. + /// Creates a new [`ClusterClient`] from the parameters. /// /// This does not create connections to the Redis Cluster, but only performs some basic checks /// on the initial nodes' URLs and passwords/usernames. @@ -52,7 +86,7 @@ impl ClusterClientBuilder { } }; - let mut cluster_params = self.cluster_params; + let mut cluster_params: ClusterParams = self.builder_params.into(); let password = if cluster_params.password.is_none() { cluster_params.password = first_node.redis.password.clone(); &cluster_params.password @@ -65,6 +99,19 @@ impl ClusterClientBuilder { } else { &None }; + if cluster_params.tls.is_none() { + cluster_params.tls = match first_node.addr { + ConnectionAddr::TcpTls { + host: _, + port: _, + insecure, + } => Some(match insecure { + false => TlsMode::Secure, + true => TlsMode::Insecure, + }), + _ => None, + }; + } let mut nodes = Vec::with_capacity(initial_nodes.len()); for node in initial_nodes { @@ -96,24 +143,39 @@ impl ClusterClientBuilder { }) } - /// Sets password for new ClusterClient. + /// Sets password for the new ClusterClient. pub fn password(mut self, password: String) -> ClusterClientBuilder { - self.cluster_params.password = Some(password); + self.builder_params.password = Some(password); self } - /// Sets username for new ClusterClient. + /// Sets username for the new ClusterClient. pub fn username(mut self, username: String) -> ClusterClientBuilder { - self.cluster_params.username = Some(username); + self.builder_params.username = Some(username); + self + } + + /// Sets number of retries for the new ClusterClient. 
+ pub fn retries(mut self, retries: u32) -> ClusterClientBuilder { + self.builder_params.retries = Some(retries); + self + } + + /// Sets TLS mode for the new ClusterClient. + /// + /// It is extracted from the first node of initial_nodes if not set. + #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))] + pub fn tls(mut self, tls: TlsMode) -> ClusterClientBuilder { + self.builder_params.tls = Some(tls); self } - /// Enables read from replicas for new ClusterClient (default is false). + /// Enables reading from replicas for all new connections (default is disabled). /// - /// If True, then read queries will go to the replica nodes & write queries will go to the + /// If enabled, then read queries will go to the replica nodes & write queries will go to the /// primary nodes. If there are no replica nodes, then all queries will go to the primary nodes. pub fn read_from_replicas(mut self) -> ClusterClientBuilder { - self.cluster_params.read_from_replicas = true; + self.builder_params.read_from_replicas = true; self } @@ -126,12 +188,12 @@ impl ClusterClientBuilder { /// Use `read_from_replicas()`. #[deprecated(since = "0.22.0", note = "Use read_from_replicas()")] pub fn readonly(mut self, read_from_replicas: bool) -> ClusterClientBuilder { - self.cluster_params.read_from_replicas = read_from_replicas; + self.builder_params.read_from_replicas = read_from_replicas; self } } -/// This is a Redis cluster client. +/// This is a Redis Cluster client. #[derive(Clone)] pub struct ClusterClient { initial_nodes: Vec, @@ -149,28 +211,66 @@ impl ClusterClient { /// Upon failure to parse initial nodes or if the initial nodes have different passwords or /// usernames, an error is returned. pub fn new(initial_nodes: Vec) -> RedisResult { - ClusterClientBuilder::new(initial_nodes).build() + Self::builder(initial_nodes).build() } - /// Creates a [`ClusterClientBuilder`] with the the provided initial_nodes. 
+ /// Creates a [`ClusterClientBuilder`] with the provided initial_nodes. pub fn builder(initial_nodes: Vec) -> ClusterClientBuilder { ClusterClientBuilder::new(initial_nodes) } - /// Creates new connections to Redis Cluster nodes and return a - /// [`ClusterConnection`]. + /// Creates new connections to Redis Cluster nodes and returns a + /// [`cluster::ClusterConnection`]. + /// + /// # Errors + /// + /// An error is returned if there is a failure while creating connections or slots. + pub fn get_connection(&self) -> RedisResult { + cluster::ClusterConnection::new(self.cluster_params.clone(), self.initial_nodes.clone()) + } + + /// Creates new connections to Redis Cluster nodes and returns a + /// [`cluster_async::ClusterConnection`]. /// /// # Errors /// /// An error is returned if there is a failure while creating connections or slots. - pub fn get_connection(&self) -> RedisResult { - ClusterConnection::new(self.cluster_params.clone(), self.initial_nodes.clone()) + #[cfg(feature = "cluster-async")] + pub async fn get_async_connection(&self) -> RedisResult { + cluster_async::ClusterConnection::new(&self.initial_nodes, self.cluster_params.clone()) + .await + } + + #[doc(hidden)] + pub fn get_generic_connection(&self) -> RedisResult> + where + C: crate::ConnectionLike + crate::cluster::Connect + Send, + { + cluster::ClusterConnection::new(self.cluster_params.clone(), self.initial_nodes.clone()) + } + + #[doc(hidden)] + #[cfg(feature = "cluster-async")] + pub async fn get_async_generic_connection( + &self, + ) -> RedisResult> + where + C: crate::aio::ConnectionLike + + cluster_async::Connect + + Clone + + Send + + Sync + + Unpin + + 'static, + { + cluster_async::ClusterConnection::new(&self.initial_nodes, self.cluster_params.clone()) + .await } /// Use `new()`. 
#[deprecated(since = "0.22.0", note = "Use new()")] pub fn open(initial_nodes: Vec) -> RedisResult { - ClusterClient::new(initial_nodes) + Self::new(initial_nodes) } } diff --git a/redis/src/cluster_pipeline.rs b/redis/src/cluster_pipeline.rs index 920d6962f..14f4fd929 100644 --- a/redis/src/cluster_pipeline.rs +++ b/redis/src/cluster_pipeline.rs @@ -113,10 +113,7 @@ impl ClusterPipeline { fail!(( UNROUTABLE_ERROR.0, UNROUTABLE_ERROR.1, - format!( - "Command '{}' can't be executed in a cluster pipeline.", - cmd_name - ) + format!("Command '{cmd_name}' can't be executed in a cluster pipeline.") )) } } diff --git a/redis/src/cluster_routing.rs b/redis/src/cluster_routing.rs index c8a9c59b2..1d1e7797d 100644 --- a/redis/src/cluster_routing.rs +++ b/redis/src/cluster_routing.rs @@ -1,11 +1,19 @@ +use std::collections::BTreeMap; use std::iter::Iterator; +use rand::seq::SliceRandom; +use rand::thread_rng; + use crate::cmd::{Arg, Cmd}; use crate::commands::is_readonly_cmd; use crate::types::Value; pub(crate) const SLOT_SIZE: u16 = 16384; +fn slot(key: &[u8]) -> u16 { + crc16::State::::calculate(key) % SLOT_SIZE +} + #[derive(Debug, Clone, Copy, PartialEq)] pub(crate) enum RoutingInfo { AllNodes, @@ -36,33 +44,33 @@ impl RoutingInfo { if key_count == 0 { Some(RoutingInfo::Random) } else { - r.arg_idx(3).and_then(|key| RoutingInfo::for_key(cmd, key)) + r.arg_idx(3).map(|key| RoutingInfo::for_key(cmd, key)) } } - b"XGROUP" | b"XINFO" => r.arg_idx(2).and_then(|key| RoutingInfo::for_key(cmd, key)), + b"XGROUP" | b"XINFO" => r.arg_idx(2).map(|key| RoutingInfo::for_key(cmd, key)), b"XREAD" | b"XREADGROUP" => { let streams_position = r.position(b"STREAMS")?; r.arg_idx(streams_position + 1) - .and_then(|key| RoutingInfo::for_key(cmd, key)) + .map(|key| RoutingInfo::for_key(cmd, key)) } _ => match r.arg_idx(1) { - Some(key) => RoutingInfo::for_key(cmd, key), + Some(key) => Some(RoutingInfo::for_key(cmd, key)), None => Some(RoutingInfo::Random), }, } } - pub fn for_key(cmd: &[u8], 
key: &[u8]) -> Option { + pub fn for_key(cmd: &[u8], key: &[u8]) -> RoutingInfo { let key = match get_hashtag(key) { Some(tag) => tag, None => key, }; - let slot = crc16::State::::calculate(key) % SLOT_SIZE; + let slot = slot(key); if is_readonly_cmd(cmd) { - Some(RoutingInfo::ReplicaSlot(slot)) + RoutingInfo::ReplicaSlot(slot) } else { - Some(RoutingInfo::MasterSlot(slot)) + RoutingInfo::MasterSlot(slot) } } } @@ -151,6 +159,78 @@ impl Slot { } } +#[derive(Eq, PartialEq)] +pub(crate) enum SlotAddr { + Master, + Replica, +} + +/// This is just a simplified version of [`Slot`], +/// which stores only the master and [optional] replica +/// to avoid the need to choose a replica each time +/// a command is executed +#[derive(Debug)] +pub(crate) struct SlotAddrs([String; 2]); + +impl SlotAddrs { + pub(crate) fn new(master_node: String, replica_node: Option) -> Self { + let replica = replica_node.unwrap_or_else(|| master_node.clone()); + Self([master_node, replica]) + } + + pub(crate) fn slot_addr(&self, slot_addr: &SlotAddr) -> &str { + match slot_addr { + SlotAddr::Master => &self.0[0], + SlotAddr::Replica => &self.0[1], + } + } + + pub(crate) fn from_slot(slot: &Slot, read_from_replicas: bool) -> Self { + let replica = if !read_from_replicas || slot.replicas().is_empty() { + None + } else { + Some( + slot.replicas() + .choose(&mut thread_rng()) + .unwrap() + .to_string(), + ) + }; + + SlotAddrs::new(slot.master().to_string(), replica) + } +} + +impl<'a> IntoIterator for &'a SlotAddrs { + type Item = &'a String; + type IntoIter = std::slice::Iter<'a, String>; + + fn into_iter(self) -> std::slice::Iter<'a, String> { + self.0.iter() + } +} + +pub(crate) type SlotMap = BTreeMap; + +/// Defines the slot and the [`SlotAddr`] to which +/// a command should be sent +#[derive(Eq, PartialEq)] +pub(crate) struct Route(u16, SlotAddr); + +impl Route { + pub(crate) fn new(slot: u16, slot_addr: SlotAddr) -> Self { + Self(slot, slot_addr) + } + + pub(crate) fn slot(&self) -> u16 { + 
self.0 + } + + pub(crate) fn slot_addr(&self) -> &SlotAddr { + &self.1 + } +} + fn get_hashtag(key: &[u8]) -> Option<&[u8]> { let open = key.iter().position(|v| *v == b'{'); let open = match open { @@ -174,7 +254,7 @@ fn get_hashtag(key: &[u8]) -> Option<&[u8]> { #[cfg(test)] mod tests { - use super::{get_hashtag, RoutingInfo}; + use super::{get_hashtag, slot, RoutingInfo}; use crate::{cmd, parser::parse_redis_value}; #[test] @@ -257,5 +337,120 @@ mod tests { RoutingInfo::for_routable(&cmd).unwrap(), ); } + + // Assert expected RoutingInfo explicitly: + + for cmd in vec![cmd("FLUSHALL"), cmd("FLUSHDB"), cmd("SCRIPT")] { + assert_eq!( + RoutingInfo::for_routable(&cmd), + Some(RoutingInfo::AllMasters) + ); + } + + for cmd in vec![ + cmd("ECHO"), + cmd("CONFIG"), + cmd("CLIENT"), + cmd("SLOWLOG"), + cmd("DBSIZE"), + cmd("LASTSAVE"), + cmd("PING"), + cmd("INFO"), + cmd("BGREWRITEAOF"), + cmd("BGSAVE"), + cmd("CLIENT LIST"), + cmd("SAVE"), + cmd("TIME"), + cmd("KEYS"), + ] { + assert_eq!(RoutingInfo::for_routable(&cmd), Some(RoutingInfo::AllNodes)); + } + + for cmd in vec![ + cmd("SCAN"), + cmd("CLIENT SETNAME"), + cmd("SHUTDOWN"), + cmd("SLAVEOF"), + cmd("REPLICAOF"), + cmd("SCRIPT KILL"), + cmd("MOVE"), + cmd("BITOP"), + ] { + assert_eq!(RoutingInfo::for_routable(&cmd), None,); + } + + for cmd in vec![ + cmd("EVAL").arg(r#"redis.call("PING");"#).arg(0), + cmd("EVALSHA").arg(r#"redis.call("PING");"#).arg(0), + ] { + assert_eq!(RoutingInfo::for_routable(cmd), Some(RoutingInfo::Random)); + } + + for (cmd, expected) in vec![ + ( + cmd("EVAL") + .arg(r#"redis.call("GET, KEYS[1]");"#) + .arg(1) + .arg("foo"), + Some(RoutingInfo::MasterSlot(slot(b"foo"))), + ), + ( + cmd("XGROUP") + .arg("CREATE") + .arg("mystream") + .arg("workers") + .arg("$") + .arg("MKSTREAM"), + Some(RoutingInfo::MasterSlot(slot(b"mystream"))), + ), + ( + cmd("XINFO").arg("GROUPS").arg("foo"), + Some(RoutingInfo::ReplicaSlot(slot(b"foo"))), + ), + ( + cmd("XREADGROUP") + .arg("GROUP") + .arg("wkrs") + 
.arg("consmrs") + .arg("STREAMS") + .arg("mystream"), + Some(RoutingInfo::MasterSlot(slot(b"mystream"))), + ), + ( + cmd("XREAD") + .arg("COUNT") + .arg("2") + .arg("STREAMS") + .arg("mystream") + .arg("writers") + .arg("0-0") + .arg("0-0"), + Some(RoutingInfo::ReplicaSlot(slot(b"mystream"))), + ), + ] { + assert_eq!(RoutingInfo::for_routable(cmd), expected,); + } + } + + #[test] + fn test_slot_for_packed_cmd() { + assert!(matches!(RoutingInfo::for_routable(&parse_redis_value(&[ + 42, 50, 13, 10, 36, 54, 13, 10, 69, 88, 73, 83, 84, 83, 13, 10, 36, 49, 54, 13, 10, + 244, 93, 23, 40, 126, 127, 253, 33, 89, 47, 185, 204, 171, 249, 96, 139, 13, 10 + ]).unwrap()), Some(RoutingInfo::ReplicaSlot(slot)) if slot == 964)); + + assert!(matches!(RoutingInfo::for_routable(&parse_redis_value(&[ + 42, 54, 13, 10, 36, 51, 13, 10, 83, 69, 84, 13, 10, 36, 49, 54, 13, 10, 36, 241, + 197, 111, 180, 254, 5, 175, 143, 146, 171, 39, 172, 23, 164, 145, 13, 10, 36, 52, + 13, 10, 116, 114, 117, 101, 13, 10, 36, 50, 13, 10, 78, 88, 13, 10, 36, 50, 13, 10, + 80, 88, 13, 10, 36, 55, 13, 10, 49, 56, 48, 48, 48, 48, 48, 13, 10 + ]).unwrap()), Some(RoutingInfo::MasterSlot(slot)) if slot == 8352)); + + assert!(matches!(RoutingInfo::for_routable(&parse_redis_value(&[ + 42, 54, 13, 10, 36, 51, 13, 10, 83, 69, 84, 13, 10, 36, 49, 54, 13, 10, 169, 233, + 247, 59, 50, 247, 100, 232, 123, 140, 2, 101, 125, 221, 66, 170, 13, 10, 36, 52, + 13, 10, 116, 114, 117, 101, 13, 10, 36, 50, 13, 10, 78, 88, 13, 10, 36, 50, 13, 10, + 80, 88, 13, 10, 36, 55, 13, 10, 49, 56, 48, 48, 48, 48, 48, 13, 10 + ]).unwrap()), Some(RoutingInfo::MasterSlot(slot)) if slot == 5210)); } } diff --git a/redis/src/cmd.rs b/redis/src/cmd.rs index f75d952fa..5035cf0e7 100644 --- a/redis/src/cmd.rs +++ b/redis/src/cmd.rs @@ -1,7 +1,8 @@ #[cfg(feature = "aio")] use futures_util::{ + future::BoxFuture, task::{Context, Poll}, - FutureExt, Stream, + Stream, StreamExt, }; #[cfg(feature = "aio")] use std::pin::Pin; @@ -70,30 +71,30 @@ 
impl<'a, T: FromRedisValue> Iterator for Iter<'a, T> { #[cfg(feature = "aio")] use crate::aio::ConnectionLike as AsyncConnection; -/// Represents a redis iterator that can be used with async connections. +/// The inner future of AsyncIter #[cfg(feature = "aio")] -pub struct AsyncIter<'a, T: FromRedisValue + 'a> { +struct AsyncIterInner<'a, T: FromRedisValue + 'a> { batch: std::vec::IntoIter, con: &'a mut (dyn AsyncConnection + Send + 'a), cmd: Cmd, } +/// Represents the state of AsyncIter #[cfg(feature = "aio")] -impl<'a, T: FromRedisValue + 'a> AsyncIter<'a, T> { - /// ```rust,no_run - /// # use redis::AsyncCommands; - /// # async fn scan_set() -> redis::RedisResult<()> { - /// # let client = redis::Client::open("redis://127.0.0.1/")?; - /// # let mut con = client.get_async_connection().await?; - /// con.sadd("my_set", 42i32).await?; - /// con.sadd("my_set", 43i32).await?; - /// let mut iter: redis::AsyncIter = con.sscan("my_set").await?; - /// while let Some(element) = iter.next_item().await { - /// assert!(element == 42 || element == 43); - /// } - /// # Ok(()) - /// # } - /// ``` +enum IterOrFuture<'a, T: FromRedisValue + 'a> { + Iter(AsyncIterInner<'a, T>), + Future(BoxFuture<'a, (AsyncIterInner<'a, T>, Option)>), + Empty, +} + +/// Represents a redis iterator that can be used with async connections. 
+#[cfg(feature = "aio")] +pub struct AsyncIter<'a, T: FromRedisValue + 'a> { + inner: IterOrFuture<'a, T>, +} + +#[cfg(feature = "aio")] +impl<'a, T: FromRedisValue + 'a> AsyncIterInner<'a, T> { #[inline] pub async fn next_item(&mut self) -> Option { // we need to do this in a loop until we produce at least one item @@ -125,13 +126,55 @@ impl<'a, T: FromRedisValue + 'a> AsyncIter<'a, T> { } #[cfg(feature = "aio")] -impl<'a, T: FromRedisValue + Unpin + 'a> Stream for AsyncIter<'a, T> { +impl<'a, T: FromRedisValue + 'a + Unpin + Send> AsyncIter<'a, T> { + /// ```rust,no_run + /// # use redis::AsyncCommands; + /// # async fn scan_set() -> redis::RedisResult<()> { + /// # let client = redis::Client::open("redis://127.0.0.1/")?; + /// # let mut con = client.get_async_connection().await?; + /// con.sadd("my_set", 42i32).await?; + /// con.sadd("my_set", 43i32).await?; + /// let mut iter: redis::AsyncIter = con.sscan("my_set").await?; + /// while let Some(element) = iter.next_item().await { + /// assert!(element == 42 || element == 43); + /// } + /// # Ok(()) + /// # } + /// ``` + #[inline] + pub async fn next_item(&mut self) -> Option { + StreamExt::next(self).await + } +} + +#[cfg(feature = "aio")] +impl<'a, T: FromRedisValue + Unpin + Send + 'a> Stream for AsyncIter<'a, T> { type Item = T; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - let mut future = Box::pin(this.next_item()); - future.poll_unpin(cx) + let mut this = self.get_mut(); + let inner = std::mem::replace(&mut this.inner, IterOrFuture::Empty); + match inner { + IterOrFuture::Iter(mut iter) => { + let fut = async move { + let next_item = iter.next_item().await; + (iter, next_item) + }; + this.inner = IterOrFuture::Future(Box::pin(fut)); + Pin::new(this).poll_next(cx) + } + IterOrFuture::Future(mut fut) => match fut.as_mut().poll(cx) { + Poll::Pending => { + this.inner = IterOrFuture::Future(fut); + Poll::Pending + } + Poll::Ready((iter, value)) => { + 
this.inner = IterOrFuture::Iter(iter); + Poll::Ready(value) + } + }, + IterOrFuture::Empty => unreachable!(), + } } } @@ -236,7 +279,7 @@ impl RedisWrite for Cmd { fn write_arg_fmt(&mut self, arg: impl fmt::Display) { use std::io::Write; - write!(self.data, "{}", arg).unwrap(); + write!(self.data, "{arg}").unwrap(); self.args.push(Arg::Simple(self.data.len())); } } @@ -416,7 +459,7 @@ impl Cmd { /// Similar to `iter()` but returns an AsyncIter over the items of the /// bulk result or iterator. A [futures::Stream](https://docs.rs/futures/0.3.3/futures/stream/trait.Stream.html) - /// can be obtained by calling `stream()` on the AsyncIter. In normal mode this is not in any way more + /// is implemented on AsyncIter. In normal mode this is not in any way more /// efficient than just querying into a `Vec` as it's internally /// implemented as buffering into a vector. This however is useful when /// `cursor_arg` was used in which case the stream will query for more @@ -449,9 +492,11 @@ impl Cmd { } Ok(AsyncIter { - batch: batch.into_iter(), - con, - cmd: self, + inner: IterOrFuture::Iter(AsyncIterInner { + batch: batch.into_iter(), + con, + cmd: self, + }), }) } diff --git a/redis/src/commands/mod.rs b/redis/src/commands/mod.rs index 64bbdf82c..19c115300 100644 --- a/redis/src/commands/mod.rs +++ b/redis/src/commands/mod.rs @@ -72,6 +72,11 @@ implement_commands! { cmd(if key.is_single_arg() { "GET" } else { "MGET" }).arg(key) } + /// Get values of keys + fn mget(key: K){ + cmd("MGET").arg(key) + } + /// Gets all keys matching pattern fn keys(key: K) { cmd("KEYS").arg(key) @@ -83,10 +88,17 @@ implement_commands! { } /// Sets multiple keys to their values. + #[allow(deprecated)] + #[deprecated(since = "0.22.4", note = "Renamed to mset() to reflect Redis name")] fn set_multiple(items: &'a [(K, V)]) { cmd("MSET").arg(items) } + /// Sets multiple keys to their values. + fn mset(items: &'a [(K, V)]) { + cmd("MSET").arg(items) + } + /// Set the value and expiration of a key. 
fn set_ex(key: K, value: V, seconds: usize) { cmd("SETEX").arg(key).arg(seconds).arg(value) @@ -186,12 +198,12 @@ implement_commands! { } /// Rename a key. - fn rename(key: K, new_key: K) { + fn rename(key: K, new_key: N) { cmd("RENAME").arg(key).arg(new_key) } /// Rename a key, only if the new key does not exist. - fn rename_nx(key: K, new_key: K) { + fn rename_nx(key: K, new_key: N) { cmd("RENAMENX").arg(key).arg(new_key) } @@ -224,7 +236,7 @@ implement_commands! { /// Sets or clears the bit at offset in the string value stored at key. fn setbit(key: K, offset: usize, value: bool) { - cmd("SETBIT").arg(key).arg(offset).arg(if value {1} else {0}) + cmd("SETBIT").arg(key).arg(offset).arg(i32::from(value)) } /// Returns the bit value at offset in the string value stored at key. @@ -244,25 +256,25 @@ implement_commands! { /// Perform a bitwise AND between multiple keys (containing string values) /// and store the result in the destination key. - fn bit_and(dstkey: K, srckeys: K) { + fn bit_and(dstkey: D, srckeys: S) { cmd("BITOP").arg("AND").arg(dstkey).arg(srckeys) } /// Perform a bitwise OR between multiple keys (containing string values) /// and store the result in the destination key. - fn bit_or(dstkey: K, srckeys: K) { + fn bit_or(dstkey: D, srckeys: S) { cmd("BITOP").arg("OR").arg(dstkey).arg(srckeys) } /// Perform a bitwise XOR between multiple keys (containing string values) /// and store the result in the destination key. - fn bit_xor(dstkey: K, srckeys: K) { + fn bit_xor(dstkey: D, srckeys: S) { cmd("BITOP").arg("XOR").arg(dstkey).arg(srckeys) } /// Perform a bitwise NOT of the key (containing string values) /// and store the result in the destination key. - fn bit_not(dstkey: K, srckey: K) { + fn bit_not(dstkey: D, srckey: S) { cmd("BITOP").arg("NOT").arg(dstkey).arg(srckey) } @@ -336,7 +348,7 @@ implement_commands! 
{ /// Pop an element from a list, push it to another list /// and return it; or block until one is available - fn blmove(srckey: K, dstkey: K, src_dir: Direction, dst_dir: Direction, timeout: usize) { + fn blmove(srckey: S, dstkey: D, src_dir: Direction, dst_dir: Direction, timeout: usize) { cmd("BLMOVE").arg(srckey).arg(dstkey).arg(src_dir).arg(dst_dir).arg(timeout) } @@ -358,7 +370,7 @@ implement_commands! { /// Pop a value from a list, push it to another list and return it; /// or block until one is available. - fn brpoplpush(srckey: K, dstkey: K, timeout: usize) { + fn brpoplpush(srckey: S, dstkey: D, timeout: usize) { cmd("BRPOPLPUSH").arg(srckey).arg(dstkey).arg(timeout) } @@ -385,7 +397,7 @@ implement_commands! { } /// Pop an element a list, push it to another list and return it - fn lmove(srckey: K, dstkey: K, src_dir: Direction, dst_dir: Direction) { + fn lmove(srckey: S, dstkey: D, src_dir: Direction, dst_dir: Direction) { cmd("LMOVE").arg(srckey).arg(dstkey).arg(src_dir).arg(dst_dir) } @@ -448,7 +460,7 @@ implement_commands! { } /// Pop a value from a list, push it to another list and return it. - fn rpoplpush(key: K, dstkey: K) { + fn rpoplpush(key: K, dstkey: D) { cmd("RPOPLPUSH").arg(key).arg(dstkey) } @@ -481,7 +493,7 @@ implement_commands! { } /// Subtract multiple sets and store the resulting set in a key. - fn sdiffstore(dstkey: K, keys: K) { + fn sdiffstore(dstkey: D, keys: K) { cmd("SDIFFSTORE").arg(dstkey).arg(keys) } @@ -491,7 +503,7 @@ implement_commands! { } /// Intersect multiple sets and store the resulting set in a key. - fn sinterstore(dstkey: K, keys: K) { + fn sinterstore(dstkey: D, keys: K) { cmd("SINTERSTORE").arg(dstkey).arg(keys) } @@ -506,7 +518,7 @@ implement_commands! { } /// Move a member from one set to another. - fn smove(srckey: K, dstkey: K, member: M) { + fn smove(srckey: S, dstkey: D, member: M) { cmd("SMOVE").arg(srckey).arg(dstkey).arg(member) } @@ -536,7 +548,7 @@ implement_commands! 
{ } /// Add multiple sets and store the resulting set in a key. - fn sunionstore(dstkey: K, keys: K) { + fn sunionstore(dstkey: D, keys: K) { cmd("SUNIONSTORE").arg(dstkey).arg(keys) } @@ -570,26 +582,26 @@ implement_commands! { /// Intersect multiple sorted sets and store the resulting sorted set in /// a new key using SUM as aggregation function. - fn zinterstore(dstkey: K, keys: &'a [K]) { + fn zinterstore(dstkey: D, keys: &'a [K]) { cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys) } /// Intersect multiple sorted sets and store the resulting sorted set in /// a new key using MIN as aggregation function. - fn zinterstore_min(dstkey: K, keys: &'a [K]) { + fn zinterstore_min(dstkey: D, keys: &'a [K]) { cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MIN") } /// Intersect multiple sorted sets and store the resulting sorted set in /// a new key using MAX as aggregation function. - fn zinterstore_max(dstkey: K, keys: &'a [K]) { + fn zinterstore_max(dstkey: D, keys: &'a [K]) { cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MAX") } /// [`Commands::zinterstore`], but with the ability to specify a /// multiplication factor for each sorted set by pairing one with each key /// in a tuple. - fn zinterstore_weights(dstkey: K, keys: &'a [(K, W)]) { + fn zinterstore_weights(dstkey: D, keys: &'a [(K, W)]) { let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight)| (key, weight)).unzip(); cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("WEIGHTS").arg(weights) } @@ -597,7 +609,7 @@ implement_commands! { /// [`Commands::zinterstore_min`], but with the ability to specify a /// multiplication factor for each sorted set by pairing one with each key /// in a tuple. 
- fn zinterstore_min_weights(dstkey: K, keys: &'a [(K, W)]) { + fn zinterstore_min_weights(dstkey: D, keys: &'a [(K, W)]) { let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight)| (key, weight)).unzip(); cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MIN").arg("WEIGHTS").arg(weights) } @@ -605,13 +617,13 @@ implement_commands! { /// [`Commands::zinterstore_max`], but with the ability to specify a /// multiplication factor for each sorted set by pairing one with each key /// in a tuple. - fn zinterstore_max_weights(dstkey: K, keys: &'a [(K, W)]) { + fn zinterstore_max_weights(dstkey: D, keys: &'a [(K, W)]) { let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight)| (key, weight)).unzip(); cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MAX").arg("WEIGHTS").arg(weights) } /// Count the number of members in a sorted set between a given lexicographical range. - fn zlexcount(key: K, min: L, max: L) { + fn zlexcount(key: K, min: M, max: MM) { cmd("ZLEXCOUNT").arg(key).arg(min).arg(max) } @@ -781,26 +793,26 @@ implement_commands! { /// Unions multiple sorted sets and store the resulting sorted set in /// a new key using SUM as aggregation function. - fn zunionstore(dstkey: K, keys: &'a [K]) { + fn zunionstore(dstkey: D, keys: &'a [K]) { cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys) } /// Unions multiple sorted sets and store the resulting sorted set in /// a new key using MIN as aggregation function. - fn zunionstore_min(dstkey: K, keys: &'a [K]) { + fn zunionstore_min(dstkey: D, keys: &'a [K]) { cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MIN") } /// Unions multiple sorted sets and store the resulting sorted set in /// a new key using MAX as aggregation function. 
- fn zunionstore_max(dstkey: K, keys: &'a [K]) { + fn zunionstore_max(dstkey: D, keys: &'a [K]) { cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MAX") } /// [`Commands::zunionstore`], but with the ability to specify a /// multiplication factor for each sorted set by pairing one with each key /// in a tuple. - fn zunionstore_weights(dstkey: K, keys: &'a [(K, W)]) { + fn zunionstore_weights(dstkey: D, keys: &'a [(K, W)]) { let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight)| (key, weight)).unzip(); cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("WEIGHTS").arg(weights) } @@ -808,7 +820,7 @@ implement_commands! { /// [`Commands::zunionstore_min`], but with the ability to specify a /// multiplication factor for each sorted set by pairing one with each key /// in a tuple. - fn zunionstore_min_weights(dstkey: K, keys: &'a [(K, W)]) { + fn zunionstore_min_weights(dstkey: D, keys: &'a [(K, W)]) { let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight)| (key, weight)).unzip(); cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MIN").arg("WEIGHTS").arg(weights) } @@ -816,7 +828,7 @@ implement_commands! { /// [`Commands::zunionstore_max`], but with the ability to specify a /// multiplication factor for each sorted set by pairing one with each key /// in a tuple. - fn zunionstore_max_weights(dstkey: K, keys: &'a [(K, W)]) { + fn zunionstore_max_weights(dstkey: D, keys: &'a [(K, W)]) { let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight)| (key, weight)).unzip(); cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MAX").arg("WEIGHTS").arg(weights) } @@ -835,7 +847,7 @@ implement_commands! { } /// Merge N different HyperLogLogs into a single one. 
- fn pfmerge(dstkey: K, srckeys: K) { + fn pfmerge(dstkey: D, srckeys: S) { cmd("PFMERGE").arg(dstkey).arg(srckeys) } @@ -1888,7 +1900,7 @@ pub enum ControlFlow { /// 10 => ControlFlow::Break(()), /// _ => ControlFlow::Continue, /// } -/// }); +/// })?; /// # Ok(()) } /// ``` // TODO In the future, it would be nice to implement Try such that `?` will work diff --git a/redis/src/connection.rs b/redis/src/connection.rs index 82732c7a1..172a226e4 100644 --- a/redis/src/connection.rs +++ b/redis/src/connection.rs @@ -18,9 +18,22 @@ use crate::types::HashMap; #[cfg(unix)] use std::os::unix::net::UnixStream; -#[cfg(feature = "tls")] +#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] use native_tls::{TlsConnector, TlsStream}; +#[cfg(feature = "tls-rustls")] +use rustls::{RootCertStore, StreamOwned}; +#[cfg(feature = "tls-rustls")] +use std::{convert::TryInto, sync::Arc}; + +#[cfg(feature = "tls-rustls-webpki-roots")] +use rustls::OwnedTrustAnchor; +#[cfg(feature = "tls-rustls-webpki-roots")] +use webpki_roots::TLS_SERVER_ROOTS; + +#[cfg(all(feature = "tls-rustls", not(feature = "tls-rustls-webpki-roots")))] +use rustls_native_certs::load_native_certs; + static DEFAULT_PORT: u16 = 6379; /// This function takes a redis URL string and parses it into a URL @@ -76,7 +89,9 @@ impl ConnectionAddr { pub fn is_supported(&self) -> bool { match *self { ConnectionAddr::Tcp(_, _) => true, - ConnectionAddr::TcpTls { .. } => cfg!(feature = "tls"), + ConnectionAddr::TcpTls { .. } => { + cfg!(any(feature = "tls-native-tls", feature = "tls-rustls")) + } ConnectionAddr::Unix(_) => cfg!(unix), } } @@ -84,9 +99,10 @@ impl ConnectionAddr { impl fmt::Display for ConnectionAddr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // Cluster::get_connection_info depends on the return value from this function match *self { - ConnectionAddr::Tcp(ref host, port) => write!(f, "{}:{}", host, port), - ConnectionAddr::TcpTls { ref host, port, .. 
} => write!(f, "{}:{}", host, port), + ConnectionAddr::Tcp(ref host, port) => write!(f, "{host}:{port}"), + ConnectionAddr::TcpTls { ref host, port, .. } => write!(f, "{host}:{port}"), ConnectionAddr::Unix(ref path) => write!(f, "{}", path.display()), } } @@ -189,7 +205,7 @@ fn url_to_tcp_connection_info(url: url::Url) -> RedisResult { }; let port = url.port().unwrap_or(DEFAULT_PORT); let addr = if url.scheme() == "rediss" { - #[cfg(feature = "tls")] + #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))] { match url.fragment() { Some("insecure") => ConnectionAddr::TcpTls { @@ -209,7 +225,7 @@ fn url_to_tcp_connection_info(url: url::Url) -> RedisResult { } } - #[cfg(not(feature = "tls"))] + #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))] fail!(( ErrorKind::InvalidClientConfig, "can't connect with TLS, the feature is not enabled" @@ -300,12 +316,18 @@ struct TcpConnection { open: bool, } -#[cfg(feature = "tls")] -struct TcpTlsConnection { +#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] +struct TcpNativeTlsConnection { reader: TlsStream, open: bool, } +#[cfg(feature = "tls-rustls")] +struct TcpRustlsConnection { + reader: StreamOwned, + open: bool, +} + #[cfg(unix)] struct UnixConnection { sock: UnixStream, @@ -314,12 +336,32 @@ struct UnixConnection { enum ActualConnection { Tcp(TcpConnection), - #[cfg(feature = "tls")] - TcpTls(Box), + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + TcpNativeTls(Box), + #[cfg(feature = "tls-rustls")] + TcpRustls(Box), #[cfg(unix)] Unix(UnixConnection), } +#[cfg(feature = "tls-rustls-insecure")] +struct NoCertificateVerification; + +#[cfg(feature = "tls-rustls-insecure")] +impl rustls::client::ServerCertVerifier for NoCertificateVerification { + fn verify_server_cert( + &self, + _end_entity: &rustls::Certificate, + _intermediates: &[rustls::Certificate], + _server_name: &rustls::ServerName, + _scts: &mut dyn Iterator, + _ocsp: &[u8], + _now: 
std::time::SystemTime, + ) -> Result { + Ok(rustls::client::ServerCertVerified::assertion()) + } +} + /// Represents a stateful redis TCP connection. pub struct Connection { con: ActualConnection, @@ -386,7 +428,7 @@ impl ActualConnection { open: true, }) } - #[cfg(feature = "tls")] + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] ConnectionAddr::TcpTls { ref host, port, @@ -440,12 +482,57 @@ impl ActualConnection { } } }; - ActualConnection::TcpTls(Box::new(TcpTlsConnection { + ActualConnection::TcpNativeTls(Box::new(TcpNativeTlsConnection { reader: tls, open: true, })) } - #[cfg(not(feature = "tls"))] + #[cfg(feature = "tls-rustls")] + ConnectionAddr::TcpTls { + ref host, + port, + insecure, + } => { + let host: &str = host; + let config = create_rustls_config(insecure)?; + let conn = rustls::ClientConnection::new(Arc::new(config), host.try_into()?)?; + let reader = match timeout { + None => { + let tcp = TcpStream::connect((host, port))?; + StreamOwned::new(conn, tcp) + } + Some(timeout) => { + let mut tcp = None; + let mut last_error = None; + for addr in (host, port).to_socket_addrs()? { + match TcpStream::connect_timeout(&addr, timeout) { + Ok(l) => { + tcp = Some(l); + break; + } + Err(e) => { + last_error = Some(e); + } + }; + } + match (tcp, last_error) { + (Some(tcp), _) => StreamOwned::new(conn, tcp), + (None, Some(e)) => { + fail!(e); + } + (None, None) => { + fail!(( + ErrorKind::InvalidClientConfig, + "could not resolve to any addresses" + )); + } + } + } + }; + + ActualConnection::TcpRustls(Box::new(TcpRustlsConnection { reader, open: true })) + } + #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))] ConnectionAddr::TcpTls { .. 
} => { fail!(( ErrorKind::InvalidClientConfig, @@ -482,8 +569,21 @@ impl ActualConnection { Ok(_) => Ok(Value::Okay), } } - #[cfg(feature = "tls")] - ActualConnection::TcpTls(ref mut connection) => { + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref mut connection) => { + let res = connection.reader.write_all(bytes).map_err(RedisError::from); + match res { + Err(e) => { + if e.is_connection_dropped() { + connection.open = false; + } + Err(e) + } + Ok(_) => Ok(Value::Okay), + } + } + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref mut connection) => { let res = connection.reader.write_all(bytes).map_err(RedisError::from); match res { Err(e) => { @@ -516,8 +616,13 @@ impl ActualConnection { ActualConnection::Tcp(TcpConnection { ref reader, .. }) => { reader.set_write_timeout(dur)?; } - #[cfg(feature = "tls")] - ActualConnection::TcpTls(ref boxed_tls_connection) => { + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref boxed_tls_connection) => { + let reader = &(boxed_tls_connection.reader); + reader.get_ref().set_write_timeout(dur)?; + } + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref boxed_tls_connection) => { let reader = &(boxed_tls_connection.reader); reader.get_ref().set_write_timeout(dur)?; } @@ -534,8 +639,13 @@ impl ActualConnection { ActualConnection::Tcp(TcpConnection { ref reader, .. 
}) => { reader.set_read_timeout(dur)?; } - #[cfg(feature = "tls")] - ActualConnection::TcpTls(ref boxed_tls_connection) => { + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref boxed_tls_connection) => { + let reader = &(boxed_tls_connection.reader); + reader.get_ref().set_read_timeout(dur)?; + } + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref boxed_tls_connection) => { let reader = &(boxed_tls_connection.reader); reader.get_ref().set_read_timeout(dur)?; } @@ -550,14 +660,60 @@ impl ActualConnection { pub fn is_open(&self) -> bool { match *self { ActualConnection::Tcp(TcpConnection { open, .. }) => open, - #[cfg(feature = "tls")] - ActualConnection::TcpTls(ref boxed_tls_connection) => boxed_tls_connection.open, + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref boxed_tls_connection) => boxed_tls_connection.open, + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref boxed_tls_connection) => boxed_tls_connection.open, #[cfg(unix)] ActualConnection::Unix(UnixConnection { open, .. }) => open, } } } +#[cfg(feature = "tls-rustls")] +pub(crate) fn create_rustls_config(insecure: bool) -> RedisResult { + let mut root_store = RootCertStore::empty(); + #[cfg(feature = "tls-rustls-webpki-roots")] + root_store.add_server_trust_anchors(TLS_SERVER_ROOTS.0.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + })); + #[cfg(all(feature = "tls-rustls", not(feature = "tls-rustls-webpki-roots")))] + for cert in load_native_certs()? { + root_store.add(&rustls::Certificate(cert.0))?; + } + + let config = rustls::ClientConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(rustls::ALL_VERSIONS)? 
+ .with_root_certificates(root_store) + .with_no_client_auth(); + + match (insecure, cfg!(feature = "tls-rustls-insecure")) { + #[cfg(feature = "tls-rustls-insecure")] + (true, true) => { + let mut config = config; + config.enable_sni = false; + config + .dangerous() + .set_certificate_verifier(Arc::new(NoCertificateVerification)); + + Ok(config) + } + (true, false) => { + fail!(( + ErrorKind::InvalidClientConfig, + "Cannot create insecure client without tls-rustls-insecure feature" + )); + } + _ => Ok(config), + } +} + fn connect_auth(con: &mut Connection, connection_info: &RedisConnectionInfo) -> RedisResult<()> { let mut command = cmd("AUTH"); if let Some(username) = &connection_info.username { @@ -808,8 +964,13 @@ impl Connection { ActualConnection::Tcp(TcpConnection { ref mut reader, .. }) => { self.parser.parse_value(reader) } - #[cfg(feature = "tls")] - ActualConnection::TcpTls(ref mut boxed_tls_connection) => { + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref mut boxed_tls_connection) => { + let reader = &mut boxed_tls_connection.reader; + self.parser.parse_value(reader) + } + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref mut boxed_tls_connection) => { let reader = &mut boxed_tls_connection.reader; self.parser.parse_value(reader) } @@ -830,11 +991,16 @@ impl Connection { let _ = connection.reader.shutdown(net::Shutdown::Both); connection.open = false; } - #[cfg(feature = "tls")] - ActualConnection::TcpTls(ref mut connection) => { + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref mut connection) => { let _ = connection.reader.shutdown(); connection.open = false; } + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref mut connection) => { + let _ = connection.reader.get_mut().shutdown(net::Shutdown::Both); + connection.open = false; + } #[cfg(unix)] ActualConnection::Unix(ref mut connection) => { let _ = 
connection.sock.shutdown(net::Shutdown::Both); @@ -1182,8 +1348,7 @@ mod tests { assert_eq!( res.is_some(), expected, - "Parsed result of `{}` is not expected", - url, + "Parsed result of `{url}` is not expected", ); } } @@ -1219,21 +1384,18 @@ mod tests { ]; for (url, expected) in cases.into_iter() { let res = url_to_tcp_connection_info(url.clone()).unwrap(); - assert_eq!(res.addr, expected.addr, "addr of {} is not expected", url); + assert_eq!(res.addr, expected.addr, "addr of {url} is not expected"); assert_eq!( res.redis.db, expected.redis.db, - "db of {} is not expected", - url + "db of {url} is not expected", ); assert_eq!( res.redis.username, expected.redis.username, - "username of {} is not expected", - url + "username of {url} is not expected", ); assert_eq!( res.redis.password, expected.redis.password, - "password of {} is not expected", - url + "password of {url} is not expected", ); } } @@ -1331,25 +1493,21 @@ mod tests { assert_eq!( ConnectionAddr::Unix(url.to_file_path().unwrap()), expected.addr, - "addr of {} is not expected", - url + "addr of {url} is not expected", ); let res = url_to_unix_connection_info(url.clone()).unwrap(); - assert_eq!(res.addr, expected.addr, "addr of {} is not expected", url); + assert_eq!(res.addr, expected.addr, "addr of {url} is not expected"); assert_eq!( res.redis.db, expected.redis.db, - "db of {} is not expected", - url + "db of {url} is not expected", ); assert_eq!( res.redis.username, expected.redis.username, - "username of {} is not expected", - url + "username of {url} is not expected", ); assert_eq!( res.redis.password, expected.redis.password, - "password of {} is not expected", - url + "password of {url} is not expected", ); } } diff --git a/redis/src/lib.rs b/redis/src/lib.rs index 09ab61df8..dacbc2f63 100644 --- a/redis/src/lib.rs +++ b/redis/src/lib.rs @@ -1,4 +1,4 @@ -//! redis-rs is a rust implementation of a Redis client library. It exposes +//! redis-rs is a Rust implementation of a Redis client library. 
It exposes //! a general purpose interface to Redis and also provides specific helpers for //! commonly used functionality. //! @@ -59,6 +59,7 @@ //! * `r2d2`: enables r2d2 connection pool support (optional) //! * `ahash`: enables ahash map/set support & uses ahash internally (+7-10% performance) (optional) //! * `cluster`: enables redis cluster support (optional) +//! * `cluster-async`: enables async redis cluster support (optional) //! * `tokio-comp`: enables support for tokio (optional) //! * `connection-manager`: enables support for automatic reconnection (optional) //! @@ -324,8 +325,9 @@ assert_eq!(result, 3); In addition to the synchronous interface that's been explained above there also exists an asynchronous interface based on [`futures`][] and [`tokio`][]. -This interface exists under the `aio` (async io) module and largely mirrors the synchronous -with a few concessions to make it fit the constraints of `futures`. +This interface exists under the `aio` (async io) module (which requires that the `aio` feature +is enabled) and largely mirrors the synchronous with a few concessions to make it fit the +constraints of `futures`. 
```rust,no_run use futures::prelude::*; @@ -448,6 +450,9 @@ mod r2d2; #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] pub mod streams; +#[cfg(feature = "cluster-async")] +pub mod cluster_async; + mod client; mod cmd; mod commands; diff --git a/redis/src/parser.rs b/redis/src/parser.rs index cc0bda8a5..45a845ed5 100644 --- a/redis/src/parser.rs +++ b/redis/src/parser.rs @@ -53,102 +53,116 @@ where } } +const MAX_RECURSE_DEPTH: usize = 100; + fn value<'a, I>( + count: Option, ) -> impl combine::Parser, PartialState = AnySendSyncPartialState> where I: RangeStream, I::Error: combine::ParseError, { - opaque!(any_send_sync_partial_state(any().then_partial( - move |&mut b| { - let line = || { - recognize(take_until_bytes(&b"\r\n"[..]).with(take(2).map(|_| ()))).and_then( - |line: &[u8]| { - str::from_utf8(&line[..line.len() - 2]).map_err(StreamErrorFor::::other) - }, + let count = count.unwrap_or(1); + + opaque!(any_send_sync_partial_state( + any() + .then_partial(move |&mut b| { + if b == b'*' && count > MAX_RECURSE_DEPTH { + combine::unexpected_any("Maximum recursion depth exceeded").left() + } else { + combine::value(b).right() + } + }) + .then_partial(move |&mut b| { + let line = || { + recognize(take_until_bytes(&b"\r\n"[..]).with(take(2).map(|_| ()))).and_then( + |line: &[u8]| { + str::from_utf8(&line[..line.len() - 2]) + .map_err(StreamErrorFor::::other) + }, + ) + }; + + let status = || { + line().map(|line| { + if line == "OK" { + Value::Okay + } else { + Value::Status(line.into()) + } + }) + }; + + let int = || { + line().and_then(|line| match line.trim().parse::() { + Err(_) => Err(StreamErrorFor::::message_static_message( + "Expected integer, got garbage", + )), + Ok(value) => Ok(value), + }) + }; + + let data = || { + int().then_partial(move |size| { + if *size < 0 { + combine::value(Value::Nil).left() + } else { + take(*size as usize) + .map(|bs: &[u8]| Value::Data(bs.to_vec())) + .skip(crlf()) + .right() + } + }) + }; + + let bulk = || { + 
int().then_partial(move |&mut length| { + if length < 0 { + combine::value(Value::Nil).map(Ok).left() + } else { + let length = length as usize; + combine::count_min_max(length, length, value(Some(count + 1))) + .map(|result: ResultExtend<_, _>| result.0.map(Value::Bulk)) + .right() + } + }) + }; + + let error = || { + line().map(|line: &str| { + let desc = "An error was signalled by the server"; + let mut pieces = line.splitn(2, ' '); + let kind = match pieces.next().unwrap() { + "ERR" => ErrorKind::ResponseError, + "EXECABORT" => ErrorKind::ExecAbortError, + "LOADING" => ErrorKind::BusyLoadingError, + "NOSCRIPT" => ErrorKind::NoScriptError, + "MOVED" => ErrorKind::Moved, + "ASK" => ErrorKind::Ask, + "TRYAGAIN" => ErrorKind::TryAgain, + "CLUSTERDOWN" => ErrorKind::ClusterDown, + "CROSSSLOT" => ErrorKind::CrossSlot, + "MASTERDOWN" => ErrorKind::MasterDown, + "READONLY" => ErrorKind::ReadOnly, + code => return make_extension_error(code, pieces.next()), + }; + match pieces.next() { + Some(detail) => RedisError::from((kind, desc, detail.to_string())), + None => RedisError::from((kind, desc)), + } + }) + }; + + combine::dispatch!(b; + b'+' => status().map(Ok), + b':' => int().map(|i| Ok(Value::Int(i))), + b'$' => data().map(Ok), + b'*' => bulk(), + b'-' => error().map(Err), + b => combine::unexpected_any(combine::error::Token(b)) ) - }; - - let status = || { - line().map(|line| { - if line == "OK" { - Value::Okay - } else { - Value::Status(line.into()) - } - }) - }; - - let int = || { - line().and_then(|line| match line.trim().parse::() { - Err(_) => Err(StreamErrorFor::::message_static_message( - "Expected integer, got garbage", - )), - Ok(value) => Ok(value), - }) - }; - - let data = || { - int().then_partial(move |size| { - if *size < 0 { - combine::value(Value::Nil).left() - } else { - take(*size as usize) - .map(|bs: &[u8]| Value::Data(bs.to_vec())) - .skip(crlf()) - .right() - } - }) - }; - - let bulk = || { - int().then_partial(|&mut length| { - if length < 0 { 
- combine::value(Value::Nil).map(Ok).left() - } else { - let length = length as usize; - combine::count_min_max(length, length, value()) - .map(|result: ResultExtend<_, _>| result.0.map(Value::Bulk)) - .right() - } - }) - }; - - let error = || { - line().map(|line: &str| { - let desc = "An error was signalled by the server"; - let mut pieces = line.splitn(2, ' '); - let kind = match pieces.next().unwrap() { - "ERR" => ErrorKind::ResponseError, - "EXECABORT" => ErrorKind::ExecAbortError, - "LOADING" => ErrorKind::BusyLoadingError, - "NOSCRIPT" => ErrorKind::NoScriptError, - "MOVED" => ErrorKind::Moved, - "ASK" => ErrorKind::Ask, - "TRYAGAIN" => ErrorKind::TryAgain, - "CLUSTERDOWN" => ErrorKind::ClusterDown, - "CROSSSLOT" => ErrorKind::CrossSlot, - "MASTERDOWN" => ErrorKind::MasterDown, - "READONLY" => ErrorKind::ReadOnly, - code => return make_extension_error(code, pieces.next()), - }; - match pieces.next() { - Some(detail) => RedisError::from((kind, desc, detail.to_string())), - None => RedisError::from((kind, desc)), - } - }) - }; - - combine::dispatch!(b; - b'+' => status().map(Ok), - b':' => int().map(|i| Ok(Value::Int(i))), - b'$' => data().map(Ok), - b'*' => bulk(), - b'-' => error().map(Err), - b => combine::unexpected_any(combine::error::Token(b)) - ) - } - ))) + }) + )) } #[cfg(feature = "aio")] @@ -174,12 +188,12 @@ mod aio_support { let buffer = &bytes[..]; let mut stream = combine::easy::Stream(combine::stream::MaybePartialStream(buffer, !eof)); - match combine::stream::decode_tokio(value(), &mut stream, &mut self.state) { + match combine::stream::decode_tokio(value(None), &mut stream, &mut self.state) { Ok(x) => x, Err(err) => { let err = err .map_position(|pos| pos.translate_position(buffer)) - .map_range(|range| format!("{:?}", range)) + .map_range(|range| format!("{range:?}")) .to_string(); return Err(RedisError::from(( ErrorKind::ResponseError, @@ -227,7 +241,7 @@ mod aio_support { where R: AsyncRead + std::marker::Unpin, { - let result = 
combine::decode_tokio!(*decoder, *read, value(), |input, _| { + let result = combine::decode_tokio!(*decoder, *read, value(None), |input, _| { combine::stream::easy::Stream::from(input) }); match result { @@ -238,7 +252,7 @@ mod aio_support { RedisError::from(io::Error::from(io::ErrorKind::UnexpectedEof)) } else { let err = err - .map_range(|range| format!("{:?}", range)) + .map_range(|range| format!("{range:?}")) .map_position(|pos| pos.translate_position(decoder.buffer())) .to_string(); RedisError::from((ErrorKind::ResponseError, "parse error", err)) @@ -285,7 +299,7 @@ impl Parser { /// Parses synchronously into a single value from the reader. pub fn parse_value(&mut self, mut reader: T) -> RedisResult { let mut decoder = &mut self.decoder; - let result = combine::decode!(decoder, reader, value(), |input, _| { + let result = combine::decode!(decoder, reader, value(None), |input, _| { combine::stream::easy::Stream::from(input) }); match result { @@ -296,7 +310,7 @@ impl Parser { RedisError::from(io::Error::from(io::ErrorKind::UnexpectedEof)) } else { let err = err - .map_range(|range| format!("{:?}", range)) + .map_range(|range| format!("{range:?}")) .map_position(|pos| pos.translate_position(decoder.buffer())) .to_string(); RedisError::from((ErrorKind::ResponseError, "parse error", err)) @@ -319,7 +333,7 @@ pub fn parse_redis_value(bytes: &[u8]) -> RedisResult { #[cfg(test)] mod tests { - #[cfg(feature = "aio")] + use super::*; #[cfg(feature = "aio")] @@ -336,4 +350,13 @@ mod tests { assert_eq!(codec.decode_eof(&mut bytes), Ok(None)); assert_eq!(codec.decode_eof(&mut bytes), Ok(None)); } + + #[test] + fn test_max_recursion_depth() { + let bytes = 
b"*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n
*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n"; + match parse_redis_value(bytes) { + Ok(_) => panic!("Expected Err"), + Err(e) => assert!(matches!(e.kind(), ErrorKind::ResponseError)), + } + } } diff --git a/redis/src/script.rs b/redis/src/script.rs index aee066422..8716b482f 100644 --- a/redis/src/script.rs +++ b/redis/src/script.rs @@ -86,6 +86,23 @@ impl Script { } .invoke(con) } + + /// Asynchronously invokes 
the script without arguments. + #[inline] + #[cfg(feature = "aio")] + pub async fn invoke_async(&self, con: &mut C) -> RedisResult + where + C: crate::aio::ConnectionLike, + T: FromRedisValue, + { + ScriptInvocation { + script: self, + args: vec![], + keys: vec![], + } + .invoke_async(con) + .await + } } /// Represents a prepared script call. diff --git a/redis/src/streams.rs b/redis/src/streams.rs index 8c0ec0428..2b3e815a9 100644 --- a/redis/src/streams.rs +++ b/redis/src/streams.rs @@ -40,45 +40,45 @@ impl ToRedisArgs for StreamMaxlen { /// #[derive(Default, Debug)] pub struct StreamClaimOptions { - /// Set IDLE cmd arg. + /// Set `IDLE ` cmd arg. idle: Option, - /// Set TIME cmd arg. + /// Set `TIME ` cmd arg. time: Option, - /// Set RETRYCOUNT cmd arg. + /// Set `RETRYCOUNT ` cmd arg. retry: Option, - /// Set FORCE cmd arg. + /// Set `FORCE` cmd arg. force: bool, - /// Set JUSTID cmd arg. Be advised: the response + /// Set `JUSTID` cmd arg. Be advised: the response /// type changes with this option. justid: bool, } impl StreamClaimOptions { - /// Set IDLE cmd arg. + /// Set `IDLE ` cmd arg. pub fn idle(mut self, ms: usize) -> Self { self.idle = Some(ms); self } - /// Set TIME cmd arg. + /// Set `TIME ` cmd arg. pub fn time(mut self, ms_time: usize) -> Self { self.time = Some(ms_time); self } - /// Set RETRYCOUNT cmd arg. + /// Set `RETRYCOUNT ` cmd arg. pub fn retry(mut self, count: usize) -> Self { self.retry = Some(count); self } - /// Set FORCE cmd arg to true. + /// Set `FORCE` cmd arg to true. pub fn with_force(mut self) -> Self { self.force = true; self } - /// Set JUSTID cmd arg to true. Be advised: the response + /// Set `JUSTID` cmd arg to true. Be advised: the response /// type changes with this option. 
pub fn with_justid(mut self) -> Self { self.justid = true; @@ -93,15 +93,15 @@ impl ToRedisArgs for StreamClaimOptions { { if let Some(ref ms) = self.idle { out.write_arg(b"IDLE"); - out.write_arg(format!("{}", ms).as_bytes()); + out.write_arg(format!("{ms}").as_bytes()); } if let Some(ref ms_time) = self.time { out.write_arg(b"TIME"); - out.write_arg(format!("{}", ms_time).as_bytes()); + out.write_arg(format!("{ms_time}").as_bytes()); } if let Some(ref count) = self.retry { out.write_arg(b"RETRYCOUNT"); - out.write_arg(format!("{}", count).as_bytes()); + out.write_arg(format!("{count}").as_bytes()); } if self.force { out.write_arg(b"FORCE"); @@ -113,8 +113,8 @@ impl ToRedisArgs for StreamClaimOptions { } /// Argument to `StreamReadOptions` -/// Represents the Redis GROUP cmd arg. -/// This option will toggle the cmd from XREAD to XREADGROUP +/// Represents the Redis `GROUP ` cmd arg. +/// This option will toggle the cmd from `XREAD` to `XREADGROUP` type SRGroup = Option<(Vec>, Vec>)>; /// Builder options for [`xread_options`] command. /// @@ -122,13 +122,13 @@ type SRGroup = Option<(Vec>, Vec>)>; /// #[derive(Default, Debug)] pub struct StreamReadOptions { - /// Set the BLOCK cmd arg. + /// Set the `BLOCK ` cmd arg. block: Option, - /// Set the COUNT cmd arg. + /// Set the `COUNT ` cmd arg. count: Option, - /// Set the NOACK cmd arg. + /// Set the `NOACK` cmd arg. noack: Option, - /// Set the GROUP cmd arg. + /// Set the `GROUP ` cmd arg. /// This option will toggle the cmd from XREAD to XREADGROUP. 
group: SRGroup, } @@ -181,12 +181,12 @@ impl ToRedisArgs for StreamReadOptions { { if let Some(ref ms) = self.block { out.write_arg(b"BLOCK"); - out.write_arg(format!("{}", ms).as_bytes()); + out.write_arg(format!("{ms}").as_bytes()); } if let Some(ref n) = self.count { out.write_arg(b"COUNT"); - out.write_arg(format!("{}", n).as_bytes()); + out.write_arg(format!("{n}").as_bytes()); } if let Some(ref group) = self.group { diff --git a/redis/src/types.rs b/redis/src/types.rs index a580a167d..36121aed7 100644 --- a/redis/src/types.rs +++ b/redis/src/types.rs @@ -199,10 +199,10 @@ impl fmt::Debug for Value { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { Value::Nil => write!(fmt, "nil"), - Value::Int(val) => write!(fmt, "int({:?})", val), + Value::Int(val) => write!(fmt, "int({val:?})"), Value::Data(ref val) => match from_utf8(val) { - Ok(x) => write!(fmt, "string-data('{:?}')", x), - Err(_) => write!(fmt, "binary-data({:?})", val), + Ok(x) => write!(fmt, "string-data('{x:?}')"), + Err(_) => write!(fmt, "binary-data({val:?})"), }, Value::Bulk(ref values) => { write!(fmt, "bulk(")?; @@ -211,13 +211,13 @@ impl fmt::Debug for Value { if !is_first { write!(fmt, ", ")?; } - write!(fmt, "{:?}", val)?; + write!(fmt, "{val:?}")?; is_first = false; } write!(fmt, ")") } Value::Okay => write!(fmt, "ok"), - Value::Status(ref s) => write!(fmt, "status({:?})", s), + Value::Status(ref s) => write!(fmt, "status({s:?})"), } } } @@ -235,7 +235,7 @@ impl From for RedisError { RedisError::from(( ErrorKind::Serialize, "Serialization Error", - format!("{}", serde_err), + format!("{serde_err}"), )) } } @@ -258,9 +258,7 @@ impl PartialEq for RedisError { &ErrorRepr::WithDescriptionAndDetail(kind_a, _, _), &ErrorRepr::WithDescriptionAndDetail(kind_b, _, _), ) => kind_a == kind_b, - (&ErrorRepr::ExtensionError(ref a, _), &ErrorRepr::ExtensionError(ref b, _)) => { - *a == *b - } + (ErrorRepr::ExtensionError(a, _), ErrorRepr::ExtensionError(b, _)) => *a == *b, _ => 
false, } } @@ -294,7 +292,7 @@ impl From for RedisError { } } -#[cfg(feature = "tls")] +#[cfg(feature = "tls-native-tls")] impl From for RedisError { fn from(err: native_tls::Error) -> RedisError { RedisError { @@ -307,6 +305,32 @@ impl From for RedisError { } } +#[cfg(feature = "tls-rustls")] +impl From for RedisError { + fn from(err: rustls::Error) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescriptionAndDetail( + ErrorKind::IoError, + "TLS error", + err.to_string(), + ), + } + } +} + +#[cfg(feature = "tls-rustls")] +impl From for RedisError { + fn from(err: rustls::client::InvalidDnsNameError) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescriptionAndDetail( + ErrorKind::IoError, + "TLS Error", + err.to_string(), + ), + } + } +} + impl From for RedisError { fn from(_: FromUtf8Error) -> RedisError { RedisError { @@ -552,7 +576,7 @@ impl RedisError { } ErrorRepr::IoError(ref e) => ErrorRepr::IoError(io::Error::new( e.kind(), - format!("{}: {}", ioerror_description, e), + format!("{ioerror_description}: {e}"), )), }; Self { repr } @@ -897,11 +921,11 @@ impl<'a, T: ToRedisArgs> ToRedisArgs for &'a [T] { where W: ?Sized + RedisWrite, { - ToRedisArgs::make_arg_vec(*self, out) + ToRedisArgs::make_arg_vec(self, out) } fn is_single_arg(&self) -> bool { - ToRedisArgs::is_single_vec_arg(*self) + ToRedisArgs::is_single_vec_arg(self) } } @@ -1017,6 +1041,26 @@ impl ToRedisArgs for BTreeMap< } } +impl ToRedisArgs + for std::collections::HashMap +{ + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + for (key, value) in self { + assert!(key.is_single_arg() && value.is_single_arg()); + + key.write_redis_args(out); + value.write_redis_args(out); + } + } + + fn is_single_arg(&self) -> bool { + self.len() <= 1 + } +} + macro_rules! to_redis_args_for_tuple { () => (); ($($name:ident,)+) => ( @@ -1083,7 +1127,7 @@ to_redis_args_for_array! { /// implement it for your own types if you want. 
/// /// In addition to what you can see from the docs, this is also implemented -/// for tuples up to size 12 and for Vec. +/// for tuples up to size 12 and for `Vec`. pub trait FromRedisValue: Sized { /// Given a redis `Value` this attempts to convert it into the given /// destination type. If that fails because it's not compatible an @@ -1097,11 +1141,11 @@ pub trait FromRedisValue: Sized { items.iter().map(FromRedisValue::from_redis_value).collect() } - /// This only exists internally as a workaround for the lack of - /// specialization. - #[doc(hidden)] + /// Convert bytes to a single element vector. fn from_byte_vec(_vec: &[u8]) -> Option> { - None + Self::from_redis_value(&Value::Data(_vec.into())) + .map(|rv| vec![rv]) + .ok() } } @@ -1138,6 +1182,7 @@ impl FromRedisValue for u8 { from_redis_value_for_num_internal!(u8, v) } + // this hack allows us to specialize Vec to work with binary data. fn from_byte_vec(vec: &[u8]) -> Option> { Some(vec.to_vec()) } @@ -1211,11 +1256,13 @@ impl FromRedisValue for String { impl FromRedisValue for Vec { fn from_redis_value(v: &Value) -> RedisResult> { match *v { - // this hack allows us to specialize Vec to work with - // binary data whereas all others will fail with an error. + // All binary data except u8 will try to parse into a single element vector. Value::Data(ref bytes) => match FromRedisValue::from_byte_vec(bytes) { Some(x) => Ok(x), - None => invalid_type_error!(v, "Response type not vector compatible."), + None => invalid_type_error!( + v, + format!("Conversion to Vec<{}> failed.", std::any::type_name::()) + ), }, Value::Bulk(ref items) => FromRedisValue::from_redis_values(items), Value::Nil => Ok(vec![]), @@ -1228,10 +1275,16 @@ impl for std::collections::HashMap { fn from_redis_value(v: &Value) -> RedisResult> { - v.as_map_iter() - .ok_or_else(|| invalid_type_error_inner!(v, "Response type not hashmap compatible"))? 
- .map(|(k, v)| Ok((from_redis_value(k)?, from_redis_value(v)?))) - .collect() + match *v { + Value::Nil => Ok(Default::default()), + _ => v + .as_map_iter() + .ok_or_else(|| { + invalid_type_error_inner!(v, "Response type not hashmap compatible") + })? + .map(|(k, v)| Ok((from_redis_value(k)?, from_redis_value(v)?))) + .collect(), + } } } @@ -1240,10 +1293,16 @@ impl for ahash::AHashMap { fn from_redis_value(v: &Value) -> RedisResult> { - v.as_map_iter() - .ok_or_else(|| invalid_type_error_inner!(v, "Response type not hashmap compatible"))? - .map(|(k, v)| Ok((from_redis_value(k)?, from_redis_value(v)?))) - .collect() + match *v { + Value::Nil => Ok(ahash::AHashMap::with_hasher(Default::default())), + _ => v + .as_map_iter() + .ok_or_else(|| { + invalid_type_error_inner!(v, "Response type not hashmap compatible") + })? + .map(|(k, v)| Ok((from_redis_value(k)?, from_redis_value(v)?))) + .collect(), + } } } diff --git a/redis/tests/support/cluster.rs b/redis/tests/support/cluster.rs index f0967d5ee..b92afbea5 100644 --- a/redis/tests/support/cluster.rs +++ b/redis/tests/support/cluster.rs @@ -7,6 +7,11 @@ use std::process; use std::thread::sleep; use std::time::Duration; +#[cfg(feature = "cluster-async")] +use redis::aio::ConnectionLike; +#[cfg(feature = "cluster-async")] +use redis::cluster_async::Connect; +use redis::ConnectionInfo; use tempfile::TempDir; use crate::support::build_keys_and_certs_for_tls; @@ -31,7 +36,7 @@ impl ClusterType { Some("tcp") => ClusterType::Tcp, Some("tcp+tls") => ClusterType::TcpTls, val => { - panic!("Unknown server type {:?}", val); + panic!("Unknown server type {val:?}"); } } } @@ -122,10 +127,9 @@ impl RedisCluster { cmd.arg("--tls-replication").arg("yes"); } } - cmd.current_dir(&tempdir.path()); + cmd.current_dir(tempdir.path()); folders.push(tempdir); - addrs.push(format!("127.0.0.1:{}", port)); - dbg!(&cmd); + addrs.push(format!("127.0.0.1:{port}")); cmd.spawn().unwrap() }, )); @@ -145,7 +149,7 @@ impl RedisCluster { if is_tls { 
cmd.arg("--tls").arg("--insecure"); } - let status = dbg!(cmd).status().unwrap(); + let status = cmd.status().unwrap(); assert!(status.success()); let cluster = RedisCluster { servers, folders }; @@ -157,10 +161,7 @@ impl RedisCluster { fn wait_for_replicas(&self, replicas: u16) { 'server: for server in &self.servers { - let conn_info = redis::ConnectionInfo { - addr: server.get_client_addr().clone(), - redis: Default::default(), - }; + let conn_info = server.connection_info(); eprintln!( "waiting until {:?} knows required number of replicas", conn_info.addr @@ -222,17 +223,15 @@ impl TestClusterContext { F: FnOnce(redis::cluster::ClusterClientBuilder) -> redis::cluster::ClusterClientBuilder, { let cluster = RedisCluster::new(nodes, replicas); - let mut builder = redis::cluster::ClusterClientBuilder::new( - cluster - .iter_servers() - .map(|server| redis::ConnectionInfo { - addr: server.get_client_addr().clone(), - redis: Default::default(), - }) - .collect(), - ); + let initial_nodes: Vec = cluster + .iter_servers() + .map(RedisServer::connection_info) + .collect(); + let mut builder = redis::cluster::ClusterClientBuilder::new(initial_nodes); builder = initializer(builder); + let client = builder.build().unwrap(); + TestClusterContext { cluster, client } } @@ -240,6 +239,23 @@ impl TestClusterContext { self.client.get_connection().unwrap() } + #[cfg(feature = "cluster-async")] + pub async fn async_connection(&self) -> redis::cluster_async::ClusterConnection { + self.client.get_async_connection().await.unwrap() + } + + #[cfg(feature = "cluster-async")] + pub async fn async_generic_connection< + C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static, + >( + &self, + ) -> redis::cluster_async::ClusterConnection { + self.client + .get_async_generic_connection::() + .await + .unwrap() + } + pub fn wait_for_cluster_up(&self) { let mut con = self.connection(); let mut c = redis::cmd("CLUSTER"); @@ -256,4 +272,21 @@ impl TestClusterContext { panic!("failed 
waiting for cluster to be ready"); } + + pub fn disable_default_user(&self) { + for server in &self.cluster.servers { + let client = redis::Client::open(server.connection_info()).unwrap(); + let mut con = client.get_connection().unwrap(); + let _: () = redis::cmd("ACL") + .arg("SETUSER") + .arg("default") + .arg("off") + .query(&mut con) + .unwrap(); + + // subsequent unauthenticated command should fail: + let mut con = client.get_connection().unwrap(); + assert!(redis::cmd("PING").query::<()>(&mut con).is_err()); + } + } } diff --git a/redis/tests/support/mock_cluster.rs b/redis/tests/support/mock_cluster.rs new file mode 100644 index 000000000..3d4af6999 --- /dev/null +++ b/redis/tests/support/mock_cluster.rs @@ -0,0 +1,264 @@ +use std::{ + collections::HashMap, + sync::{Arc, RwLock}, + time::Duration, +}; + +use redis::cluster::{self, ClusterClient, ClusterClientBuilder}; + +use { + once_cell::sync::Lazy, + redis::{IntoConnectionInfo, RedisResult, Value}, +}; + +#[cfg(feature = "cluster-async")] +use redis::{aio, cluster_async, RedisFuture}; + +#[cfg(feature = "cluster-async")] +use futures::future; + +#[cfg(feature = "cluster-async")] +use tokio::runtime::Runtime; + +type Handler = Arc Result<(), RedisResult> + Send + Sync>; + +static HANDLERS: Lazy>> = Lazy::new(Default::default); + +#[derive(Clone)] +pub struct MockConnection { + pub handler: Handler, + pub port: u16, +} + +#[cfg(feature = "cluster-async")] +impl cluster_async::Connect for MockConnection { + fn connect<'a, T>(info: T) -> RedisFuture<'a, Self> + where + T: IntoConnectionInfo + Send + 'a, + { + let info = info.into_connection_info().unwrap(); + + let (name, port) = match &info.addr { + redis::ConnectionAddr::Tcp(addr, port) => (addr, *port), + _ => unreachable!(), + }; + Box::pin(future::ok(MockConnection { + handler: HANDLERS + .read() + .unwrap() + .get(name) + .unwrap_or_else(|| panic!("Handler `{name}` were not installed")) + .clone(), + port, + })) + } +} + +impl cluster::Connect for 
MockConnection { + fn connect<'a, T>(info: T, _timeout: Option) -> RedisResult + where + T: IntoConnectionInfo, + { + let info = info.into_connection_info().unwrap(); + + let (name, port) = match &info.addr { + redis::ConnectionAddr::Tcp(addr, port) => (addr, *port), + _ => unreachable!(), + }; + Ok(MockConnection { + handler: HANDLERS + .read() + .unwrap() + .get(name) + .unwrap_or_else(|| panic!("Handler `{name}` were not installed")) + .clone(), + port, + }) + } + + fn send_packed_command(&mut self, _cmd: &[u8]) -> RedisResult<()> { + Ok(()) + } + + fn set_write_timeout(&self, _dur: Option) -> RedisResult<()> { + Ok(()) + } + + fn set_read_timeout(&self, _dur: Option) -> RedisResult<()> { + Ok(()) + } + + fn recv_response(&mut self) -> RedisResult { + Ok(Value::Nil) + } +} + +pub fn contains_slice(xs: &[u8], ys: &[u8]) -> bool { + for i in 0..xs.len() { + if xs[i..].starts_with(ys) { + return true; + } + } + false +} + +pub fn respond_startup(name: &str, cmd: &[u8]) -> Result<(), RedisResult> { + if contains_slice(cmd, b"PING") { + Err(Ok(Value::Status("OK".into()))) + } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + Err(Ok(Value::Bulk(vec![Value::Bulk(vec![ + Value::Int(0), + Value::Int(16383), + Value::Bulk(vec![ + Value::Data(name.as_bytes().to_vec()), + Value::Int(6379), + ]), + ])]))) + } else if contains_slice(cmd, b"READONLY") { + Err(Ok(Value::Status("OK".into()))) + } else { + Ok(()) + } +} + +pub fn respond_startup_with_replica(name: &str, cmd: &[u8]) -> Result<(), RedisResult> { + if contains_slice(cmd, b"PING") { + Err(Ok(Value::Status("OK".into()))) + } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + Err(Ok(Value::Bulk(vec![Value::Bulk(vec![ + Value::Int(0), + Value::Int(16383), + Value::Bulk(vec![ + Value::Data(name.as_bytes().to_vec()), + Value::Int(6379), + ]), + Value::Bulk(vec![ + Value::Data(name.as_bytes().to_vec()), + Value::Int(6380), + ]), + ])]))) + } else if contains_slice(cmd, 
b"READONLY") { + Err(Ok(Value::Status("OK".into()))) + } else { + Ok(()) + } +} + +#[cfg(feature = "cluster-async")] +impl aio::ConnectionLike for MockConnection { + fn req_packed_command<'a>(&'a mut self, cmd: &'a redis::Cmd) -> RedisFuture<'a, Value> { + Box::pin(future::ready( + (self.handler)(&cmd.get_packed_command(), self.port) + .expect_err("Handler did not specify a response"), + )) + } + + fn req_packed_commands<'a>( + &'a mut self, + _pipeline: &'a redis::Pipeline, + _offset: usize, + _count: usize, + ) -> RedisFuture<'a, Vec> { + Box::pin(future::ok(vec![])) + } + + fn get_db(&self) -> i64 { + 0 + } +} + +impl redis::ConnectionLike for MockConnection { + fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult { + (self.handler)(cmd, self.port).expect_err("Handler did not specify a response") + } + + fn req_packed_commands( + &mut self, + _cmd: &[u8], + _offset: usize, + _count: usize, + ) -> RedisResult> { + Ok(vec![]) + } + + fn get_db(&self) -> i64 { + 0 + } + + fn check_connection(&mut self) -> bool { + true + } + + fn is_open(&self) -> bool { + true + } +} + +pub struct MockEnv { + #[cfg(feature = "cluster-async")] + pub runtime: Runtime, + pub client: redis::cluster::ClusterClient, + pub connection: redis::cluster::ClusterConnection, + #[cfg(feature = "cluster-async")] + pub async_connection: redis::cluster_async::ClusterConnection, + #[allow(unused)] + pub handler: RemoveHandler, +} + +pub struct RemoveHandler(Vec); + +impl Drop for RemoveHandler { + fn drop(&mut self) { + for id in &self.0 { + HANDLERS.write().unwrap().remove(id); + } + } +} + +impl MockEnv { + pub fn new( + id: &str, + handler: impl Fn(&[u8], u16) -> Result<(), RedisResult> + Send + Sync + 'static, + ) -> Self { + Self::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{id}")]), + id, + handler, + ) + } + + pub fn with_client_builder( + client_builder: ClusterClientBuilder, + id: &str, + handler: impl Fn(&[u8], u16) -> Result<(), RedisResult> + Send + Sync 
+ 'static, + ) -> Self { + #[cfg(feature = "cluster-async")] + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .unwrap(); + + let id = id.to_string(); + HANDLERS + .write() + .unwrap() + .insert(id.clone(), Arc::new(move |cmd, port| handler(cmd, port))); + + let client = client_builder.build().unwrap(); + let connection = client.get_generic_connection().unwrap(); + #[cfg(feature = "cluster-async")] + let async_connection = runtime + .block_on(client.get_async_generic_connection()) + .unwrap(); + MockEnv { + #[cfg(feature = "cluster-async")] + runtime, + client, + connection, + #[cfg(feature = "cluster-async")] + async_connection, + handler: RemoveHandler(vec![id]), + } + } +} diff --git a/redis/tests/support/mod.rs b/redis/tests/support/mod.rs index 5d3a73ac9..73318f887 100644 --- a/redis/tests/support/mod.rs +++ b/redis/tests/support/mod.rs @@ -16,6 +16,8 @@ pub fn current_thread_runtime() -> tokio::runtime::Runtime { #[cfg(feature = "aio")] builder.enable_io(); + builder.enable_time(); + builder.build().unwrap() } @@ -33,12 +35,18 @@ where async_std::task::block_on(f) } -#[cfg(feature = "cluster")] +#[cfg(any(feature = "cluster", feature = "cluster-async"))] mod cluster; -#[cfg(feature = "cluster")] +#[cfg(any(feature = "cluster", feature = "cluster-async"))] +mod mock_cluster; + +#[cfg(any(feature = "cluster", feature = "cluster-async"))] pub use self::cluster::*; +#[cfg(any(feature = "cluster", feature = "cluster-async"))] +pub use self::mock_cluster::*; + #[derive(PartialEq)] enum ServerType { Tcp { tls: bool }, @@ -66,7 +74,7 @@ impl ServerType { Some("tcp+tls") => ServerType::Tcp { tls: true }, Some("unix") => ServerType::Unix, val => { - panic!("Unknown server type {:?}", val); + panic!("Unknown server type {val:?}"); } } } @@ -102,13 +110,13 @@ impl RedisServer { } ServerType::Unix => { let (a, b) = rand::random::<(u64, u64)>(); - let path = format!("/tmp/redis-rs-test-{}-{}.sock", a, b); + let path 
= format!("/tmp/redis-rs-test-{a}-{b}.sock"); redis::ConnectionAddr::Unix(PathBuf::from(&path)) } }; RedisServer::new_with_addr(addr, None, modules, |cmd| { cmd.spawn() - .unwrap_or_else(|err| panic!("Failed to run {:?}: {}", cmd, err)) + .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}")) }) } @@ -191,7 +199,7 @@ impl RedisServer { .arg("--port") .arg("0") .arg("--unixsocket") - .arg(&path); + .arg(path); RedisServer { process: spawner(&mut redis_cmd), tempdir: Some(tempdir), @@ -201,15 +209,22 @@ impl RedisServer { } } - pub fn get_client_addr(&self) -> &redis::ConnectionAddr { + pub fn client_addr(&self) -> &redis::ConnectionAddr { &self.addr } + pub fn connection_info(&self) -> redis::ConnectionInfo { + redis::ConnectionInfo { + addr: self.client_addr().clone(), + redis: Default::default(), + } + } + pub fn stop(&mut self) { let _ = self.process.kill(); let _ = self.process.wait(); - if let redis::ConnectionAddr::Unix(ref path) = *self.get_client_addr() { - fs::remove_file(&path).ok(); + if let redis::ConnectionAddr::Unix(ref path) = *self.client_addr() { + fs::remove_file(path).ok(); } } } @@ -233,11 +248,7 @@ impl TestContext { pub fn with_modules(modules: &[Module]) -> TestContext { let server = RedisServer::with_modules(modules); - let client = redis::Client::open(redis::ConnectionInfo { - addr: server.get_client_addr().clone(), - redis: Default::default(), - }) - .unwrap(); + let client = redis::Client::open(server.connection_info()).unwrap(); let mut con; let millisecond = Duration::from_millis(1); @@ -249,10 +260,10 @@ impl TestContext { sleep(millisecond); retries += 1; if retries > 100000 { - panic!("Tried to connect too many times, last error: {}", err); + panic!("Tried to connect too many times, last error: {err}"); } } else { - panic!("Could not connect: {}", err); + panic!("Could not connect: {err}"); } } Ok(x) => { @@ -314,7 +325,7 @@ where #![allow(clippy::write_with_newline)] match *value { Value::Nil => write!(writer, "$-1\r\n"), - 
Value::Int(val) => write!(writer, ":{}\r\n", val), + Value::Int(val) => write!(writer, ":{val}\r\n"), Value::Data(ref val) => { write!(writer, "${}\r\n", val.len())?; writer.write_all(val)?; @@ -328,7 +339,7 @@ where Ok(()) } Value::Okay => write!(writer, "+OK\r\n"), - Value::Status(ref s) => write!(writer, "+{}\r\n", s), + Value::Status(ref s) => write!(writer, "+{s}\r\n"), } } @@ -347,13 +358,14 @@ pub fn build_keys_and_certs_for_tls(tempdir: &TempDir) -> TlsFilePaths { let ca_serial = tempdir.path().join("ca.txt"); let redis_crt = tempdir.path().join("redis.crt"); let redis_key = tempdir.path().join("redis.key"); + let ext_file = tempdir.path().join("openssl.cnf"); fn make_key>(name: S, size: usize) { process::Command::new("openssl") .arg("genrsa") .arg("-out") .arg(name) - .arg(&format!("{}", size)) + .arg(&format!("{size}")) .stdout(process::Stdio::null()) .stderr(process::Stdio::null()) .spawn() @@ -390,6 +402,10 @@ pub fn build_keys_and_certs_for_tls(tempdir: &TempDir) -> TlsFilePaths { .wait() .expect("failed to create CA cert"); + // Build x509v3 extensions file + fs::write(&ext_file, b"keyUsage = digitalSignature, keyEncipherment") + .expect("failed to create x509v3 extensions file"); + // Read redis key let mut key_cmd = process::Command::new("openssl") .arg("req") @@ -418,6 +434,8 @@ pub fn build_keys_and_certs_for_tls(tempdir: &TempDir) -> TlsFilePaths { .arg("-CAcreateserial") .arg("-days") .arg("365") + .arg("-extfile") + .arg(&ext_file) .arg("-out") .arg(&redis_crt) .stdin(key_cmd.stdout.take().expect("should have stdout")) diff --git a/redis/tests/test_acl.rs b/redis/tests/test_acl.rs index 069d36248..b5846d550 100644 --- a/redis/tests/test_acl.rs +++ b/redis/tests/test_acl.rs @@ -122,7 +122,7 @@ fn test_acl_cat() { "scripting", ]; for cat in expects.iter() { - assert!(res.contains(*cat), "Category `{}` does not exist", cat); + assert!(res.contains(*cat), "Category `{cat}` does not exist"); } let expects = vec!["pfmerge", "pfcount", "pfselftest", 
"pfadd"]; @@ -130,7 +130,7 @@ fn test_acl_cat() { .acl_cat_categoryname("hyperloglog") .expect("Got commands of a category"); for cmd in expects.iter() { - assert!(res.contains(*cmd), "Command `{}` does not exist", cmd); + assert!(res.contains(*cmd), "Command `{cmd}` does not exist"); } } diff --git a/redis/tests/test_async.rs b/redis/tests/test_async.rs index 68fb7d390..62f6ee501 100644 --- a/redis/tests/test_async.rs +++ b/redis/tests/test_async.rs @@ -1,4 +1,4 @@ -use futures::{future, prelude::*}; +use futures::{future, prelude::*, StreamExt}; use redis::{aio::MultiplexedConnection, cmd, AsyncCommands, ErrorKind, RedisResult}; use crate::support::*; @@ -136,12 +136,12 @@ fn test_pipeline_transaction_with_errors() { fn test_cmd(con: &MultiplexedConnection, i: i32) -> impl Future> + Send { let mut con = con.clone(); async move { - let key = format!("key{}", i); + let key = format!("key{i}"); let key_2 = key.clone(); - let key2 = format!("key{}_2", i); + let key2 = format!("key{i}_2"); let key2_2 = key2.clone(); - let foo_val = format!("foo{}", i); + let foo_val = format!("foo{i}"); redis::cmd("SET") .arg(&key[..]) @@ -229,7 +229,7 @@ fn test_transaction_multiplexed_connection() { let mut con = con.clone(); async move { let foo_val = i; - let bar_val = format!("bar{}", i); + let bar_val = format!("bar{i}"); let mut pipe = redis::pipe(); pipe.atomic() @@ -323,6 +323,7 @@ fn test_script() { // into Redis and when they need to be loaded in let script1 = redis::Script::new("return redis.call('SET', KEYS[1], ARGV[1])"); let script2 = redis::Script::new("return redis.call('GET', KEYS[1])"); + let script3 = redis::Script::new("return redis.call('KEYS', '*')"); let ctx = TestContext::new(); @@ -335,6 +336,8 @@ fn test_script() { .await?; let val: String = script2.key("key1").invoke_async(&mut con).await?; assert_eq!(val, "foo"); + let keys: Vec = script3.invoke_async(&mut con).await?; + assert_eq!(keys, ["key1"]); script1 .key("key1") .arg("bar") @@ -342,6 +345,8 @@ fn 
test_script() { .await?; let val: String = script2.key("key1").invoke_async(&mut con).await?; assert_eq!(val, "bar"); + let keys: Vec = script3.invoke_async(&mut con).await?; + assert_eq!(keys, ["key1"]); Ok::<_, RedisError>(()) }) .unwrap(); @@ -401,7 +406,7 @@ async fn io_error_on_kill_issue_320() { .await .unwrap(); - eprintln!("{}", client_list); + eprintln!("{client_list}"); let client_to_kill = client_list .split('\n') .find(|line| line.contains("to-kill")) @@ -437,7 +442,7 @@ async fn io_error_on_kill_issue_320() { async fn invalid_password_issue_343() { let ctx = TestContext::new(); let coninfo = redis::ConnectionInfo { - addr: ctx.server.get_client_addr().clone(), + addr: ctx.server.client_addr().clone(), redis: redis::RedisConnectionInfo { db: 0, username: None, @@ -453,11 +458,47 @@ async fn invalid_password_issue_343() { assert_eq!( err.kind(), ErrorKind::AuthenticationFailed, - "Unexpected error: {}", - err + "Unexpected error: {err}", ); } +// Test issue of Stream trait blocking if we try to iterate more than 10 items +// https://github.com/mitsuhiko/redis-rs/issues/537 and https://github.com/mitsuhiko/redis-rs/issues/583 +#[tokio::test] +async fn test_issue_stream_blocks() { + let ctx = TestContext::new(); + let mut con = ctx.multiplexed_async_connection().await.unwrap(); + for i in 0..20usize { + let _: () = con.append(format!("test/{i}"), i).await.unwrap(); + } + let values = con.scan_match::<&str, String>("test/*").await.unwrap(); + tokio::time::timeout(std::time::Duration::from_millis(100), async move { + let values: Vec<_> = values.collect().await; + assert_eq!(values.len(), 20); + }) + .await + .unwrap(); +} + +// Test issue of AsyncCommands::scan returning the wrong number of keys +// https://github.com/redis-rs/redis-rs/issues/759 +#[tokio::test] +async fn test_issue_async_commands_scan_broken() { + let ctx = TestContext::new(); + let mut con = ctx.async_connection().await.unwrap(); + let mut keys: Vec = (0..100).map(|k| 
format!("async-key{k}")).collect(); + keys.sort(); + for key in &keys { + let _: () = con.set(key, b"foo").await.unwrap(); + } + + let iter: redis::AsyncIter = con.scan().await.unwrap(); + let mut keys_from_redis: Vec<_> = iter.collect().await; + keys_from_redis.sort(); + assert_eq!(keys, keys_from_redis); + assert_eq!(keys.len(), 100); +} + mod pub_sub { use std::collections::HashMap; use std::time::Duration; @@ -569,4 +610,31 @@ mod pub_sub { }) .unwrap(); } + + #[test] + fn pipe_errors_do_not_affect_subsequent_commands() { + use redis::RedisError; + + let ctx = TestContext::new(); + block_on_all(async move { + let mut conn = ctx.multiplexed_async_connection().await?; + + conn.lpush::<&str, &str, ()>("key", "value").await?; + + let res: Result<(String, usize), redis::RedisError> = redis::pipe() + .get("key") // WRONGTYPE + .llen("key") + .query_async(&mut conn) + .await; + + assert!(res.is_err()); + + let list: Vec = conn.lrange("key", 0, -1).await?; + + assert_eq!(list, vec!["value".to_owned()]); + + Ok::<_, RedisError>(()) + }) + .unwrap(); + } } diff --git a/redis/tests/test_async_async_std.rs b/redis/tests/test_async_async_std.rs index 23f6863be..d2a300dc1 100644 --- a/redis/tests/test_async_async_std.rs +++ b/redis/tests/test_async_async_std.rs @@ -126,12 +126,12 @@ fn test_pipeline_transaction() { fn test_cmd(con: &MultiplexedConnection, i: i32) -> impl Future> + Send { let mut con = con.clone(); async move { - let key = format!("key{}", i); + let key = format!("key{i}"); let key_2 = key.clone(); - let key2 = format!("key{}_2", i); + let key2 = format!("key{i}_2"); let key2_2 = key2.clone(); - let foo_val = format!("foo{}", i); + let foo_val = format!("foo{i}"); redis::cmd("SET") .arg(&key[..]) @@ -219,7 +219,7 @@ fn test_transaction_multiplexed_connection() { let mut con = con.clone(); async move { let foo_val = i; - let bar_val = format!("bar{}", i); + let bar_val = format!("bar{i}"); let mut pipe = redis::pipe(); pipe.atomic() @@ -261,6 +261,7 @@ fn 
test_script() { // into Redis and when they need to be loaded in let script1 = redis::Script::new("return redis.call('SET', KEYS[1], ARGV[1])"); let script2 = redis::Script::new("return redis.call('GET', KEYS[1])"); + let script3 = redis::Script::new("return redis.call('KEYS', '*')"); let ctx = TestContext::new(); @@ -273,6 +274,8 @@ fn test_script() { .await?; let val: String = script2.key("key1").invoke_async(&mut con).await?; assert_eq!(val, "foo"); + let keys: Vec = script3.invoke_async(&mut con).await?; + assert_eq!(keys, ["key1"]); script1 .key("key1") .arg("bar") @@ -280,6 +283,8 @@ fn test_script() { .await?; let val: String = script2.key("key1").invoke_async(&mut con).await?; assert_eq!(val, "bar"); + let keys: Vec = script3.invoke_async(&mut con).await?; + assert_eq!(keys, ["key1"]); Ok::<_, RedisError>(()) }) .unwrap(); diff --git a/redis/tests/test_basic.rs b/redis/tests/test_basic.rs index 4e2544b6d..215053c2d 100644 --- a/redis/tests/test_basic.rs +++ b/redis/tests/test_basic.rs @@ -9,6 +9,7 @@ use std::collections::{BTreeMap, BTreeSet}; use std::collections::{HashMap, HashSet}; use std::thread::{sleep, spawn}; use std::time::Duration; +use std::vec; use crate::support::*; @@ -790,7 +791,7 @@ fn test_auto_m_versions() { let ctx = TestContext::new(); let mut con = ctx.connection(); - assert_eq!(con.set_multiple(&[("key1", 1), ("key2", 2)]), Ok(())); + assert_eq!(con.mset(&[("key1", 1), ("key2", 2)]), Ok(())); assert_eq!(con.get(&["key1", "key2"]), Ok((1, 2))); assert_eq!(con.get(vec!["key1", "key2"]), Ok((1, 2))); assert_eq!(con.get(&vec!["key1", "key2"]), Ok((1, 2))); @@ -1124,7 +1125,7 @@ fn test_object_commands() { "int" ); - assert_eq!(con.object_idletime::<_, i32>("object_key_str").unwrap(), 0); + assert!(con.object_idletime::<_, i32>("object_key_str").unwrap() <= 1); assert_eq!(con.object_refcount::<_, i32>("object_key_str").unwrap(), 1); // Needed for OBJECT FREQ and can't be set before object_idletime @@ -1140,3 +1141,46 @@ fn 
test_object_commands() { // get after that assert_eq!(con.object_freq::<_, i32>("object_key_str").unwrap(), 1); } + +#[test] +fn test_mget() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let _: () = con.set(1, "1").unwrap(); + let data: Vec = con.mget(&[1]).unwrap(); + assert_eq!(data, vec!["1"]); + + let _: () = con.set(2, "2").unwrap(); + let data: Vec = con.mget(&[1, 2]).unwrap(); + assert_eq!(data, vec!["1", "2"]); + + let data: Vec> = con.mget(&[4]).unwrap(); + assert_eq!(data, vec![None]); + + let data: Vec> = con.mget(&[2, 4]).unwrap(); + assert_eq!(data, vec![Some("2".to_string()), None]); +} + +#[test] +fn test_variable_length_get() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let _: () = con.set(1, "1").unwrap(); + let keys = vec![1]; + assert_eq!(keys.len(), 1); + let data: Vec = con.get(&keys).unwrap(); + assert_eq!(data, vec!["1"]); +} + +#[test] +fn test_multi_generics() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.sadd(b"set1", vec![5, 42]), Ok(2)); + assert_eq!(con.sadd(999_i64, vec![42, 123]), Ok(2)); + let _: () = con.rename(999_i64, b"set2").unwrap(); + assert_eq!(con.sunionstore("res", &[b"set1", b"set2"]), Ok(3)); +} diff --git a/redis/tests/test_cluster.rs b/redis/tests/test_cluster.rs index 6704ec4d8..bd64fb5e3 100644 --- a/redis/tests/test_cluster.rs +++ b/redis/tests/test_cluster.rs @@ -1,7 +1,12 @@ #![cfg(feature = "cluster")] mod support; +use std::sync::{atomic, Arc}; + use crate::support::*; -use redis::cluster::cluster_pipe; +use redis::{ + cluster::{cluster_pipe, ClusterClient}, + cmd, parse_redis_value, Value, +}; #[test] fn test_cluster_basics() { @@ -29,6 +34,8 @@ fn test_cluster_with_username_and_password() { .username(RedisCluster::username().to_string()) .password(RedisCluster::password().to_string()) }); + cluster.disable_default_user(); + let mut con = cluster.connection(); redis::cmd("SET") @@ -208,8 +215,8 @@ fn 
test_cluster_pipeline_command_ordering() { let mut queries = Vec::new(); let mut expected = Vec::new(); for i in 0..100 { - queries.push(format!("foo{}", i)); - expected.push(format!("bar{}", i)); + queries.push(format!("foo{i}")); + expected.push(format!("bar{i}")); pipe.set(&queries[i], &expected[i]).ignore(); } pipe.execute(&mut con); @@ -237,8 +244,8 @@ fn test_cluster_pipeline_ordering_with_improper_command() { if i == 5 { pipe.cmd("hset").arg("foo").ignore(); } else { - let query = format!("foo{}", i); - let r = format!("bar{}", i); + let query = format!("foo{i}"); + let r = format!("bar{i}"); pipe.set(&query, &r).ignore(); queries.push(query); expected.push(r); @@ -256,3 +263,173 @@ fn test_cluster_pipeline_ordering_with_improper_command() { let got = pipe.query::>(&mut con).unwrap(); assert_eq!(got, expected); } + +#[test] +fn test_cluster_retries() { + let name = "tryagain"; + + let requests = atomic::AtomicUsize::new(0); + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(5), + name, + move |cmd: &[u8], _| { + respond_startup(name, cmd)?; + + match requests.fetch_add(1, atomic::Ordering::SeqCst) { + 0..=4 => Err(parse_redis_value(b"-TRYAGAIN mock\r\n")), + _ => Err(Ok(Value::Data(b"123".to_vec()))), + } + }, + ); + + let value = cmd("GET").arg("test").query::>(&mut connection); + + assert_eq!(value, Ok(Some(123))); +} + +#[test] +fn test_cluster_exhaust_retries() { + let name = "tryagain_exhaust_retries"; + + let requests = Arc::new(atomic::AtomicUsize::new(0)); + + let MockEnv { + mut connection, + handler: _handler, + .. 
+ } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(2), + name, + { + let requests = requests.clone(); + move |cmd: &[u8], _| { + respond_startup(name, cmd)?; + requests.fetch_add(1, atomic::Ordering::SeqCst); + Err(parse_redis_value(b"-TRYAGAIN mock\r\n")) + } + }, + ); + + let result = cmd("GET").arg("test").query::>(&mut connection); + + assert_eq!( + result.map_err(|err| err.to_string()), + Err("An error was signalled by the server: mock".to_string()) + ); + assert_eq!(requests.load(atomic::Ordering::SeqCst), 3); +} + +#[test] +fn test_cluster_rebuild_with_extra_nodes() { + let name = "rebuild_with_extra_nodes"; + + let requests = atomic::AtomicUsize::new(0); + let started = atomic::AtomicBool::new(false); + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::new(name, move |cmd: &[u8], port| { + if !started.load(atomic::Ordering::SeqCst) { + respond_startup(name, cmd)?; + } + started.store(true, atomic::Ordering::SeqCst); + + if contains_slice(cmd, b"PING") { + return Err(Ok(Value::Status("OK".into()))); + } + + let i = requests.fetch_add(1, atomic::Ordering::SeqCst); + + match i { + // Respond that the key exists elswehere (the slot, 123, is unused in the + // implementation) + 0 => Err(parse_redis_value(b"-MOVED 123\r\n")), + // Respond with the new masters + 1 => Err(Ok(Value::Bulk(vec![ + Value::Bulk(vec![ + Value::Int(0), + Value::Int(1), + Value::Bulk(vec![ + Value::Data(name.as_bytes().to_vec()), + Value::Int(6379), + ]), + ]), + Value::Bulk(vec![ + Value::Int(2), + Value::Int(16383), + Value::Bulk(vec![ + Value::Data(name.as_bytes().to_vec()), + Value::Int(6380), + ]), + ]), + ]))), + _ => { + // Check that the correct node receives the request after rebuilding + assert_eq!(port, 6380); + Err(Ok(Value::Data(b"123".to_vec()))) + } + } + }); + + let value = cmd("GET").arg("test").query::>(&mut connection); + + assert_eq!(value, Ok(Some(123))); +} + +#[test] +fn 
test_cluster_replica_read() { + let name = "node"; + + // requests should route to replica + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |cmd: &[u8], port| { + respond_startup_with_replica(name, cmd)?; + + match port { + 6380 => Err(Ok(Value::Data(b"123".to_vec()))), + _ => panic!("Wrong node"), + } + }, + ); + + let value = cmd("GET").arg("test").query::>(&mut connection); + assert_eq!(value, Ok(Some(123))); + + // requests should route to primary + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |cmd: &[u8], port| { + respond_startup_with_replica(name, cmd)?; + match port { + 6379 => Err(Ok(Value::Status("OK".into()))), + _ => panic!("Wrong node"), + } + }, + ); + + let value = cmd("SET") + .arg("test") + .arg("123") + .query::>(&mut connection); + assert_eq!(value, Ok(Some(Value::Status("OK".to_owned())))); +} diff --git a/redis/tests/test_cluster_async.rs b/redis/tests/test_cluster_async.rs new file mode 100644 index 000000000..e74742b7d --- /dev/null +++ b/redis/tests/test_cluster_async.rs @@ -0,0 +1,525 @@ +#![cfg(feature = "cluster-async")] +mod support; +use std::sync::{ + atomic::{self, AtomicI32}, + atomic::{AtomicBool, Ordering}, + Arc, +}; + +use futures::prelude::*; +use futures::stream; +use once_cell::sync::Lazy; +use redis::{ + aio::{ConnectionLike, MultiplexedConnection}, + cluster::ClusterClient, + cluster_async::Connect, + cmd, parse_redis_value, AsyncCommands, Cmd, InfoDict, IntoConnectionInfo, RedisError, + RedisFuture, RedisResult, Script, Value, +}; + +use crate::support::*; + +#[test] +fn test_async_cluster_basic_cmd() { + let cluster = TestClusterContext::new(3, 0); + + block_on_all(async move { + let mut connection = 
cluster.async_connection().await; + cmd("SET") + .arg("test") + .arg("test_data") + .query_async(&mut connection) + .await?; + let res: String = cmd("GET") + .arg("test") + .clone() + .query_async(&mut connection) + .await?; + assert_eq!(res, "test_data"); + Ok::<_, RedisError>(()) + }) + .unwrap(); +} + +#[test] +fn test_async_cluster_basic_eval() { + let cluster = TestClusterContext::new(3, 0); + + block_on_all(async move { + let mut connection = cluster.async_connection().await; + let res: String = cmd("EVAL") + .arg(r#"redis.call("SET", KEYS[1], ARGV[1]); return redis.call("GET", KEYS[1])"#) + .arg(1) + .arg("key") + .arg("test") + .query_async(&mut connection) + .await?; + assert_eq!(res, "test"); + Ok::<_, RedisError>(()) + }) + .unwrap(); +} + +#[ignore] // TODO Handle running SCRIPT LOAD on all masters +#[test] +fn test_async_cluster_basic_script() { + let cluster = TestClusterContext::new(3, 0); + + block_on_all(async move { + let mut connection = cluster.async_connection().await; + let res: String = Script::new( + r#"redis.call("SET", KEYS[1], ARGV[1]); return redis.call("GET", KEYS[1])"#, + ) + .key("key") + .arg("test") + .invoke_async(&mut connection) + .await?; + assert_eq!(res, "test"); + Ok::<_, RedisError>(()) + }) + .unwrap(); +} + +#[ignore] // TODO Handle pipe where the keys do not all go to the same node +#[test] +fn test_async_cluster_basic_pipe() { + let cluster = TestClusterContext::new(3, 0); + + block_on_all(async move { + let mut connection = cluster.async_connection().await; + let mut pipe = redis::pipe(); + pipe.add_command(cmd("SET").arg("test").arg("test_data").clone()); + pipe.add_command(cmd("SET").arg("test3").arg("test_data3").clone()); + pipe.query_async(&mut connection).await?; + let res: String = connection.get("test").await?; + assert_eq!(res, "test_data"); + let res: String = connection.get("test3").await?; + assert_eq!(res, "test_data3"); + Ok::<_, RedisError>(()) + }) + .unwrap() +} + +#[test] +fn 
test_async_cluster_basic_failover() { + block_on_all(async move { + test_failover(&TestClusterContext::new(6, 1), 10, 123).await; + Ok::<_, RedisError>(()) + }) + .unwrap() +} + +async fn do_failover(redis: &mut redis::aio::MultiplexedConnection) -> Result<(), anyhow::Error> { + cmd("CLUSTER").arg("FAILOVER").query_async(redis).await?; + Ok(()) +} + +async fn test_failover(env: &TestClusterContext, requests: i32, value: i32) { + let completed = Arc::new(AtomicI32::new(0)); + + let connection = env.async_connection().await; + let mut node_conns: Vec = Vec::new(); + + 'outer: loop { + node_conns.clear(); + let cleared_nodes = async { + for server in env.cluster.iter_servers() { + let addr = server.client_addr(); + let client = redis::Client::open(server.connection_info()) + .unwrap_or_else(|e| panic!("Failed to connect to '{addr}': {e}")); + let mut conn = client + .get_multiplexed_async_connection() + .await + .unwrap_or_else(|e| panic!("Failed to get connection: {e}")); + + let info: InfoDict = redis::Cmd::new() + .arg("INFO") + .query_async(&mut conn) + .await + .expect("INFO"); + let role: String = info.get("role").expect("cluster role"); + + if role == "master" { + tokio::time::timeout(std::time::Duration::from_secs(3), async { + Ok(redis::Cmd::new() + .arg("FLUSHALL") + .query_async(&mut conn) + .await?) 
+ }) + .await + .unwrap_or_else(|err| Err(anyhow::Error::from(err)))?; + } + + node_conns.push(conn); + } + Ok::<_, anyhow::Error>(()) + } + .await; + match cleared_nodes { + Ok(()) => break 'outer, + Err(err) => { + // Failed to clear the databases, retry + log::warn!("{}", err); + } + } + } + + (0..requests + 1) + .map(|i| { + let mut connection = connection.clone(); + let mut node_conns = node_conns.clone(); + let completed = completed.clone(); + async move { + if i == requests / 2 { + // Failover all the nodes, error only if all the failover requests error + node_conns + .iter_mut() + .map(do_failover) + .collect::>() + .fold( + Err(anyhow::anyhow!("None")), + |acc: Result<(), _>, result: Result<(), _>| async move { + acc.or(result) + }, + ) + .await + } else { + let key = format!("test-{value}-{i}"); + cmd("SET") + .arg(&key) + .arg(i) + .clone() + .query_async(&mut connection) + .await?; + let res: i32 = cmd("GET") + .arg(key) + .clone() + .query_async(&mut connection) + .await?; + assert_eq!(res, i); + completed.fetch_add(1, Ordering::SeqCst); + Ok::<_, anyhow::Error>(()) + } + } + }) + .collect::>() + .try_collect() + .await + .unwrap_or_else(|e| panic!("{e}")); + + assert_eq!( + completed.load(Ordering::SeqCst), + requests, + "Some requests never completed!" 
+ ); +} + +static ERROR: Lazy = Lazy::new(Default::default); + +#[derive(Clone)] +struct ErrorConnection { + inner: MultiplexedConnection, +} + +impl Connect for ErrorConnection { + fn connect<'a, T>(info: T) -> RedisFuture<'a, Self> + where + T: IntoConnectionInfo + Send + 'a, + { + Box::pin(async { + let inner = MultiplexedConnection::connect(info).await?; + Ok(ErrorConnection { inner }) + }) + } +} + +impl ConnectionLike for ErrorConnection { + fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> { + if ERROR.load(Ordering::SeqCst) { + Box::pin(async move { Err(RedisError::from((redis::ErrorKind::Moved, "ERROR"))) }) + } else { + self.inner.req_packed_command(cmd) + } + } + + fn req_packed_commands<'a>( + &'a mut self, + pipeline: &'a redis::Pipeline, + offset: usize, + count: usize, + ) -> RedisFuture<'a, Vec> { + self.inner.req_packed_commands(pipeline, offset, count) + } + + fn get_db(&self) -> i64 { + self.inner.get_db() + } +} + +#[test] +fn test_async_cluster_error_in_inner_connection() { + let cluster = TestClusterContext::new(3, 0); + + block_on_all(async move { + let mut con = cluster.async_generic_connection::().await; + + ERROR.store(false, Ordering::SeqCst); + let r: Option = con.get("test").await?; + assert_eq!(r, None::); + + ERROR.store(true, Ordering::SeqCst); + + let result: RedisResult<()> = con.get("test").await; + assert_eq!( + result, + Err(RedisError::from((redis::ErrorKind::Moved, "ERROR"))) + ); + + Ok::<_, RedisError>(()) + }) + .unwrap(); +} + +#[test] +#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] +fn test_async_cluster_async_std_basic_cmd() { + let cluster = TestClusterContext::new(3, 0); + + block_on_all_using_async_std(async { + let mut connection = cluster.async_connection().await; + redis::cmd("SET") + .arg("test") + .arg("test_data") + .query_async(&mut connection) + .await?; + redis::cmd("GET") + .arg("test") + .clone() + .query_async(&mut connection) + .map_ok(|res: String| { + 
assert_eq!(res, "test_data"); + }) + .await + }) + .unwrap(); +} + +#[test] +fn test_async_cluster_retries() { + let name = "tryagain"; + + let requests = atomic::AtomicUsize::new(0); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(5), + name, + move |cmd: &[u8], _| { + respond_startup(name, cmd)?; + + match requests.fetch_add(1, atomic::Ordering::SeqCst) { + 0..=4 => Err(parse_redis_value(b"-TRYAGAIN mock\r\n")), + _ => Err(Ok(Value::Data(b"123".to_vec()))), + } + }, + ); + + let value = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Option>(&mut connection), + ); + + assert_eq!(value, Ok(Some(123))); +} + +#[test] +fn test_async_cluster_tryagain_exhaust_retries() { + let name = "tryagain_exhaust_retries"; + + let requests = Arc::new(atomic::AtomicUsize::new(0)); + + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(2), + name, + { + let requests = requests.clone(); + move |cmd: &[u8], _| { + respond_startup(name, cmd)?; + requests.fetch_add(1, atomic::Ordering::SeqCst); + Err(parse_redis_value(b"-TRYAGAIN mock\r\n")) + } + }, + ); + + let result = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Option>(&mut connection), + ); + + assert_eq!( + result.map_err(|err| err.to_string()), + Err("An error was signalled by the server: mock".to_string()) + ); + assert_eq!(requests.load(atomic::Ordering::SeqCst), 3); +} + +#[test] +fn test_async_cluster_rebuild_with_extra_nodes() { + let name = "rebuild_with_extra_nodes"; + + let requests = atomic::AtomicUsize::new(0); + let started = atomic::AtomicBool::new(false); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. 
+ } = MockEnv::new(name, move |cmd: &[u8], port| { + if !started.load(atomic::Ordering::SeqCst) { + respond_startup(name, cmd)?; + } + started.store(true, atomic::Ordering::SeqCst); + + if contains_slice(cmd, b"PING") { + return Err(Ok(Value::Status("OK".into()))); + } + + let i = requests.fetch_add(1, atomic::Ordering::SeqCst); + + match i { + // Respond that the key exists elswehere (the slot, 123, is unused in the + // implementation) + 0 => Err(parse_redis_value(b"-MOVED 123\r\n")), + // Respond with the new masters + 1 => Err(Ok(Value::Bulk(vec![ + Value::Bulk(vec![ + Value::Int(0), + Value::Int(1), + Value::Bulk(vec![ + Value::Data(name.as_bytes().to_vec()), + Value::Int(6379), + ]), + ]), + Value::Bulk(vec![ + Value::Int(2), + Value::Int(16383), + Value::Bulk(vec![ + Value::Data(name.as_bytes().to_vec()), + Value::Int(6380), + ]), + ]), + ]))), + _ => { + // Check that the correct node receives the request after rebuilding + assert_eq!(port, 6380); + Err(Ok(Value::Data(b"123".to_vec()))) + } + } + }); + + let value = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Option>(&mut connection), + ); + + assert_eq!(value, Ok(Some(123))); +} + +#[test] +fn test_async_cluster_replica_read() { + let name = "node"; + + // requests should route to replica + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |cmd: &[u8], port| { + respond_startup_with_replica(name, cmd)?; + + match port { + 6380 => Err(Ok(Value::Data(b"123".to_vec()))), + _ => panic!("Wrong node"), + } + }, + ); + + let value = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Option>(&mut connection), + ); + assert_eq!(value, Ok(Some(123))); + + // requests should route to primary + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. 
+ } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |cmd: &[u8], port| { + respond_startup_with_replica(name, cmd)?; + match port { + 6379 => Err(Ok(Value::Status("OK".into()))), + _ => panic!("Wrong node"), + } + }, + ); + + let value = runtime.block_on( + cmd("SET") + .arg("test") + .arg("123") + .query_async::<_, Option>(&mut connection), + ); + assert_eq!(value, Ok(Some(Value::Status("OK".to_owned())))); +} + +#[test] +fn test_async_cluster_with_username_and_password() { + let cluster = TestClusterContext::new_with_cluster_client_builder(3, 0, |builder| { + builder + .username(RedisCluster::username().to_string()) + .password(RedisCluster::password().to_string()) + }); + cluster.disable_default_user(); + + block_on_all(async move { + let mut connection = cluster.async_connection().await; + cmd("SET") + .arg("test") + .arg("test_data") + .query_async(&mut connection) + .await?; + let res: String = cmd("GET") + .arg("test") + .clone() + .query_async(&mut connection) + .await?; + assert_eq!(res, "test_data"); + Ok::<_, RedisError>(()) + }) + .unwrap(); +} diff --git a/redis/tests/test_json.rs b/redis/tests/test_module_json.rs similarity index 88% rename from redis/tests/test_json.rs rename to redis/tests/test_module_json.rs index 09fed8979..49d3e51f5 100644 --- a/redis/tests/test_json.rs +++ b/redis/tests/test_module_json.rs @@ -20,7 +20,7 @@ use serde_json::{self, json}; const TEST_KEY: &str = "my_json"; #[test] -fn test_json_serialize_error() { +fn test_module_json_serialize_error() { let ctx = TestContext::with_modules(&[Module::Json]); let mut con = ctx.connection(); @@ -52,7 +52,7 @@ fn test_json_serialize_error() { } #[test] -fn test_json_arr_append() { +fn test_module_json_arr_append() { let ctx = TestContext::with_modules(&[Module::Json]); let mut con = ctx.connection(); @@ -70,7 +70,7 @@ fn test_json_arr_append() { } #[test] -fn test_json_arr_index() { +fn 
test_module_json_arr_index() { let ctx = TestContext::with_modules(&[Module::Json]); let mut con = ctx.connection(); @@ -94,13 +94,14 @@ fn test_json_arr_index() { assert_eq!(update_initial, Ok(true)); - let json_arrindex_2: RedisResult = con.json_arr_index_ss(TEST_KEY, "$..a", &2i64, 0, 0); + let json_arrindex_2: RedisResult = + con.json_arr_index_ss(TEST_KEY, "$..a", &2i64, &0, &0); assert_eq!(json_arrindex_2, Ok(Bulk(vec![Int(1i64), Nil]))); } #[test] -fn test_json_arr_insert() { +fn test_module_json_arr_insert() { let ctx = TestContext::with_modules(&[Module::Json]); let mut con = ctx.connection(); @@ -130,7 +131,7 @@ fn test_json_arr_insert() { } #[test] -fn test_json_arr_len() { +fn test_module_json_arr_len() { let ctx = TestContext::with_modules(&[Module::Json]); let mut con = ctx.connection(); @@ -160,7 +161,7 @@ fn test_json_arr_len() { } #[test] -fn test_json_arr_pop() { +fn test_module_json_arr_pop() { let ctx = TestContext::with_modules(&[Module::Json]); let mut con = ctx.connection(); @@ -200,7 +201,7 @@ fn test_json_arr_pop() { } #[test] -fn test_json_arr_trim() { +fn test_module_json_arr_trim() { let ctx = TestContext::with_modules(&[Module::Json]); let mut con = ctx.connection(); @@ -230,7 +231,7 @@ fn test_json_arr_trim() { } #[test] -fn test_json_clear() { +fn test_module_json_clear() { let ctx = TestContext::with_modules(&[Module::Json]); let mut con = ctx.connection(); @@ -255,7 +256,7 @@ fn test_json_clear() { } #[test] -fn test_json_del() { +fn test_module_json_del() { let ctx = TestContext::with_modules(&[Module::Json]); let mut con = ctx.connection(); @@ -273,7 +274,7 @@ fn test_json_del() { } #[test] -fn test_json_get() { +fn test_module_json_get() { let ctx = TestContext::with_modules(&[Module::Json]); let mut con = ctx.connection(); @@ -289,23 +290,27 @@ fn test_json_get() { assert_eq!(json_get, Ok("[3,null]".into())); - let json_get_multi: RedisResult = con.json_get(TEST_KEY, "..a $..b"); + let json_get_multi: RedisResult = 
con.json_get(TEST_KEY, vec!["..a", "$..b"]); - assert_eq!(json_get_multi, Ok("2".into())); + if json_get_multi != Ok("{\"$..b\":[3,null],\"..a\":[2,4]}".into()) + && json_get_multi != Ok("{\"..a\":[2,4],\"$..b\":[3,null]}".into()) + { + panic!("test_error: incorrect response from json_get_multi"); + } } #[test] -fn test_json_mget() { +fn test_module_json_mget() { let ctx = TestContext::with_modules(&[Module::Json]); let mut con = ctx.connection(); let set_initial_a: RedisResult = con.json_set( - format!("{}-a", TEST_KEY), + format!("{TEST_KEY}-a"), "$", &json!({"a":1i64, "b": 2i64, "nested": {"a": 3i64, "b": null}}), ); let set_initial_b: RedisResult = con.json_set( - format!("{}-b", TEST_KEY), + format!("{TEST_KEY}-b"), "$", &json!({"a":4i64, "b": 5i64, "nested": {"a": 6i64, "b": null}}), ); @@ -313,8 +318,8 @@ fn test_json_mget() { assert_eq!(set_initial_a, Ok(true)); assert_eq!(set_initial_b, Ok(true)); - let json_mget: RedisResult = con.json_mget( - vec![format!("{}-a", TEST_KEY), format!("{}-b", TEST_KEY)], + let json_mget: RedisResult = con.json_get( + vec![format!("{TEST_KEY}-a"), format!("{TEST_KEY}-b")], "$..a", ); @@ -328,7 +333,7 @@ fn test_json_mget() { } #[test] -fn test_json_numincrby() { +fn test_module_json_num_incr_by() { let ctx = TestContext::with_modules(&[Module::Json]); let mut con = ctx.connection(); @@ -340,19 +345,19 @@ fn test_json_numincrby() { assert_eq!(set_initial, Ok(true)); - let json_numincrby_a: RedisResult = con.json_numincrby(TEST_KEY, "$.a", 2); + let json_numincrby_a: RedisResult = con.json_num_incr_by(TEST_KEY, "$.a", 2); // cannot increment a string assert_eq!(json_numincrby_a, Ok("[null]".into())); - let json_numincrby_b: RedisResult = con.json_numincrby(TEST_KEY, "$..a", 2); + let json_numincrby_b: RedisResult = con.json_num_incr_by(TEST_KEY, "$..a", 2); // however numbers can be incremented assert_eq!(json_numincrby_b, Ok("[null,4,7,null]".into())); } #[test] -fn test_json_objkeys() { +fn test_module_json_obj_keys() { let 
ctx = TestContext::with_modules(&[Module::Json]); let mut con = ctx.connection(); @@ -364,7 +369,7 @@ fn test_json_objkeys() { assert_eq!(set_initial, Ok(true)); - let json_objkeys: RedisResult = con.json_objkeys(TEST_KEY, "$..a"); + let json_objkeys: RedisResult = con.json_obj_keys(TEST_KEY, "$..a"); assert_eq!( json_objkeys, @@ -379,7 +384,7 @@ fn test_json_objkeys() { } #[test] -fn test_json_objlen() { +fn test_module_json_obj_len() { let ctx = TestContext::with_modules(&[Module::Json]); let mut con = ctx.connection(); @@ -391,13 +396,13 @@ fn test_json_objlen() { assert_eq!(set_initial, Ok(true)); - let json_objlen: RedisResult = con.json_objlen(TEST_KEY, "$..a"); + let json_objlen: RedisResult = con.json_obj_len(TEST_KEY, "$..a"); assert_eq!(json_objlen, Ok(Bulk(vec![Nil, Int(2)]))); } #[test] -fn test_json_set() { +fn test_module_json_set() { let ctx = TestContext::with_modules(&[Module::Json]); let mut con = ctx.connection(); @@ -407,7 +412,7 @@ fn test_json_set() { } #[test] -fn test_json_strappend() { +fn test_module_json_str_append() { let ctx = TestContext::with_modules(&[Module::Json]); let mut con = ctx.connection(); @@ -419,7 +424,7 @@ fn test_json_strappend() { assert_eq!(set_initial, Ok(true)); - let json_strappend: RedisResult = con.json_strappend(TEST_KEY, "$..a", "\"baz\""); + let json_strappend: RedisResult = con.json_str_append(TEST_KEY, "$..a", "\"baz\""); assert_eq!(json_strappend, Ok(Bulk(vec![Int(6), Int(8), Nil]))); @@ -432,7 +437,7 @@ fn test_json_strappend() { } #[test] -fn test_json_strlen() { +fn test_module_json_str_len() { let ctx = TestContext::with_modules(&[Module::Json]); let mut con = ctx.connection(); @@ -444,13 +449,13 @@ fn test_json_strlen() { assert_eq!(set_initial, Ok(true)); - let json_strlen: RedisResult = con.json_strlen(TEST_KEY, "$..a"); + let json_strlen: RedisResult = con.json_str_len(TEST_KEY, "$..a"); assert_eq!(json_strlen, Ok(Bulk(vec![Int(3), Int(5), Nil]))); } #[test] -fn test_json_toggle() { +fn 
test_module_json_toggle() { let ctx = TestContext::with_modules(&[Module::Json]); let mut con = ctx.connection(); @@ -466,7 +471,7 @@ fn test_json_toggle() { } #[test] -fn test_json_type() { +fn test_module_json_type() { let ctx = TestContext::with_modules(&[Module::Json]); let mut con = ctx.connection(); diff --git a/redis/tests/test_streams.rs b/redis/tests/test_streams.rs index b58c8de94..3f297324a 100644 --- a/redis/tests/test_streams.rs +++ b/redis/tests/test_streams.rs @@ -361,6 +361,42 @@ fn test_xadd_maxlen_map() { assert_eq!(reply.ids[2].get("idx"), Some("9".to_string())); } +#[test] +fn test_xread_options_deleted_pel_entry() { + // Test xread_options behaviour with deleted entry + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let result: RedisResult = con.xgroup_create_mkstream("k1", "g1", "$"); + assert!(result.is_ok()); + let _: RedisResult = + con.xadd_maxlen("k1", StreamMaxlen::Equals(1), "*", &[("h1", "w1")]); + // read the pending items for this key & group + let result: StreamReadReply = con + .xread_options( + &["k1"], + &[">"], + &StreamReadOptions::default().group("g1", "c1"), + ) + .unwrap(); + + let _: RedisResult = + con.xadd_maxlen("k1", StreamMaxlen::Equals(1), "*", &[("h2", "w2")]); + let result_deleted_entry: StreamReadReply = con + .xread_options( + &["k1"], + &["0"], + &StreamReadOptions::default().group("g1", "c1"), + ) + .unwrap(); + assert_eq!( + result.keys[0].ids.len(), + result_deleted_entry.keys[0].ids.len() + ); + assert_eq!( + result.keys[0].ids[0].id, + result_deleted_entry.keys[0].ids[0].id + ); +} #[test] fn test_xclaim() { // Tests the following commands.... 
diff --git a/redis/tests/test_types.rs b/redis/tests/test_types.rs index 8d6f65402..281bf3d9e 100644 --- a/redis/tests/test_types.rs +++ b/redis/tests/test_types.rs @@ -74,6 +74,42 @@ fn test_vec() { assert_eq!(v, Ok(vec![1i32, 2, 3])); } +#[test] +fn test_single_bool_vec() { + use redis::{FromRedisValue, Value}; + + let v = FromRedisValue::from_redis_value(&Value::Data("1".into())); + + assert_eq!(v, Ok(vec![true])); +} + +#[test] +fn test_single_i32_vec() { + use redis::{FromRedisValue, Value}; + + let v = FromRedisValue::from_redis_value(&Value::Data("1".into())); + + assert_eq!(v, Ok(vec![1i32])); +} + +#[test] +fn test_single_u32_vec() { + use redis::{FromRedisValue, Value}; + + let v = FromRedisValue::from_redis_value(&Value::Data("42".into())); + + assert_eq!(v, Ok(vec![42u32])); +} + +#[test] +fn test_single_string_vec() { + use redis::{FromRedisValue, Value}; + + let v = FromRedisValue::from_redis_value(&Value::Data("1".into())); + + assert_eq!(v, Ok(vec!["1".to_string()])); +} + #[test] fn test_tuple() { use redis::{FromRedisValue, Value}; @@ -230,6 +266,7 @@ fn test_types_to_redis_args() { use redis::ToRedisArgs; use std::collections::BTreeMap; use std::collections::BTreeSet; + use std::collections::HashMap; use std::collections::HashSet; assert!(!5i32.to_redis_args().is_empty()); @@ -258,4 +295,12 @@ fn test_types_to_redis_args() { .collect::>() .to_redis_args() .is_empty()); + + // this can also be used on something HMSET + assert!(![("d", 8), ("e", 9), ("f", 10)] + .iter() + .cloned() + .collect::>() + .to_redis_args() + .is_empty()); }