// rustis/client/client.rs

1#[cfg(test)]
2use crate::commands::DebugCommands;
3#[cfg(feature = "redis-graph")]
4use crate::commands::GraphCommands;
5#[cfg(feature = "redis-json")]
6use crate::commands::JsonCommands;
7#[cfg(feature = "redis-search")]
8use crate::commands::SearchCommands;
9#[cfg(feature = "redis-time-series")]
10use crate::commands::TimeSeriesCommands;
11#[cfg(feature = "redis-bloom")]
12use crate::commands::{
13    BloomCommands, CountMinSketchCommands, CuckooCommands, TDigestCommands, TopKCommands,
14};
15use crate::{
16    client::{
17        ClientState, ClientTrackingInvalidationStream, IntoConfig, Message, MonitorStream,
18        Pipeline, PreparedCommand, PubSubStream, Transaction,
19    },
20    commands::{
21        BitmapCommands, BlockingCommands, ClusterCommands, ConnectionCommands, GenericCommands,
22        GeoCommands, HashCommands, HyperLogLogCommands, InternalPubSubCommands, ListCommands,
23        PubSubCommands, ScriptingCommands, SentinelCommands, ServerCommands, SetCommands,
24        SortedSetCommands, StreamCommands, StringCommands, TransactionCommands,
25    },
26    network::{
27        timeout, JoinHandle, MsgSender, NetworkHandler, PubSubReceiver, PubSubSender, PushReceiver,
28        PushSender, ReconnectReceiver, ReconnectSender, ResultReceiver, ResultSender,
29        ResultsReceiver, ResultsSender,
30    },
31    resp::{cmd, Command, CommandArgs, RespBuf, Response, SingleArg, SingleArgCollection},
32    Error, Future, Result,
33};
34use futures_channel::{mpsc, oneshot};
35use futures_util::Stream;
36use log::{info, trace};
37use serde::de::DeserializeOwned;
38use std::{
39    future::IntoFuture,
40    sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard},
41    time::Duration,
42};
43
44/// Client with a unique connection to a Redis server.
45#[derive(Clone)]
46pub struct Client {
47    msg_sender: Arc<Option<MsgSender>>,
48    network_task_join_handle: Arc<Option<JoinHandle<()>>>,
49    reconnect_sender: ReconnectSender,
50    client_state: Arc<RwLock<ClientState>>,
51    command_timeout: Duration,
52    retry_on_error: bool,
53}
54
55impl Drop for Client {
56    /// if this client is the last client on the shared connection, the channel to send messages
57    /// to the underlying network handler will be closed explicitely
58    fn drop(&mut self) {
59        let mut network_task_join_handle: Arc<Option<JoinHandle<()>>> = Arc::new(None);
60        std::mem::swap(
61            &mut network_task_join_handle,
62            &mut self.network_task_join_handle,
63        );
64
65        // stop the network loop if we are the last reference to its handle
66        if Arc::try_unwrap(network_task_join_handle).is_ok() {
67            let mut msg_sender: Arc<Option<MsgSender>> = Arc::new(None);
68            std::mem::swap(&mut msg_sender, &mut self.msg_sender);
69
70            if let Ok(Some(msg_sender)) = Arc::try_unwrap(msg_sender) {
71                // the network loop will automatically ends when it detects the sender bound has been closed
72                msg_sender.close_channel();
73            }
74        };
75    }
76}
77
78impl Client {
79    /// Connects asynchronously to the Redis server.
80    ///
81    /// # Errors
82    /// Any Redis driver [`Error`](crate::Error) that occurs during the connection operation
83    #[inline]
84    pub async fn connect(config: impl IntoConfig) -> Result<Self> {
85        let config = config.into_config()?;
86        let command_timeout = config.command_timeout;
87        let retry_on_error = config.retry_on_error;
88        let (msg_sender, network_task_join_handle, reconnect_sender) =
89            NetworkHandler::connect(config.into_config()?).await?;
90
91        Ok(Self {
92            msg_sender: Arc::new(Some(msg_sender)),
93            network_task_join_handle: Arc::new(Some(network_task_join_handle)),
94            reconnect_sender,
95            client_state: Arc::new(RwLock::new(ClientState::new())),
96            command_timeout,
97            retry_on_error,
98        })
99    }
100
101    /// if this client is the last client on the shared connection, the channel to send messages
102    /// to the underlying network handler will be closed explicitely.
103    ///
104    /// Then, this function will await for the network handler to be ended
105    pub async fn close(mut self) -> Result<()> {
106        let mut network_task_join_handle: Arc<Option<JoinHandle<()>>> = Arc::new(None);
107        std::mem::swap(
108            &mut network_task_join_handle,
109            &mut self.network_task_join_handle,
110        );
111
112        // stop the network loop if we are the last reference to its handle
113        if let Ok(Some(network_task_join_handle)) = Arc::try_unwrap(network_task_join_handle) {
114            let mut msg_sender: Arc<Option<MsgSender>> = Arc::new(None);
115            std::mem::swap(&mut msg_sender, &mut self.msg_sender);
116
117            if let Ok(Some(msg_sender)) = Arc::try_unwrap(msg_sender) {
118                // the network loop will automatically ends when it detects the sender bound has been closed
119                msg_sender.close_channel();
120                network_task_join_handle.await?;
121            }
122        };
123
124        Ok(())
125    }
126
127    /// Used to receive notifications when the client reconnects to the Redis server.
128    ///
129    /// To turn this receiver into a Stream, you can use the
130    /// [`BroadcastStream`](https://docs.rs/tokio-stream/latest/tokio_stream/wrappers/struct.BroadcastStream.html) wrapper.
131    pub fn on_reconnect(&self) -> ReconnectReceiver {
132        self.reconnect_sender.subscribe()
133    }
134
135    /// Give an immutable generic access to attach any state to a client instance
136    pub fn get_client_state(&self) -> RwLockReadGuard<ClientState> {
137        self.client_state.read().unwrap()
138    }
139
140    /// Give a mutable generic access to attach any state to a client instance
141    pub fn get_client_state_mut(&self) -> RwLockWriteGuard<ClientState> {
142        self.client_state.write().unwrap()
143    }
144
145    /// Send an arbitrary command to the server.
146    ///
147    /// This is used primarily intended for implementing high level commands API
148    /// but may also be used to provide access to new features that lack a direct API.
149    ///
150    /// # Arguments
151    /// * `command` - generic [`Command`](crate::resp::Command) meant to be sent to the Redis server.
152    /// * `retry_on_error` - retry to send the command on network error.
153    ///   * `None` - default behaviour defined in [`Config::retry_on_error`](crate::client::Config::retry_on_error)
154    ///   * `Some(true)` - retry sending command on network error
155    ///   * `Some(false)` - do not retry sending command on network error
156    ///
157    /// # Errors
158    /// Any Redis driver [`Error`](crate::Error) that occurs during the send operation
159    ///
160    /// # Example
161    /// ```
162    /// use rustis::{client::Client, resp::cmd, Result};
163    ///
164    /// #[cfg_attr(feature = "tokio-runtime", tokio::main)]
165    /// #[cfg_attr(feature = "async-std-runtime", async_std::main)]
166    /// async fn main() -> Result<()> {
167    ///     let client = Client::connect("127.0.0.1:6379").await?;
168    ///
169    ///     client
170    ///         .send(
171    ///             cmd("MSET")
172    ///                 .arg("key1")
173    ///                 .arg("value1")
174    ///                 .arg("key2")
175    ///                 .arg("value2")
176    ///                  .arg("key3")
177    ///                 .arg("value3")
178    ///                 .arg("key4")
179    ///                 .arg("value4"),
180    ///             None,
181    ///         )
182    ///         .await?
183    ///         .to::<()>()?;
184    ///
185    ///     let values: Vec<String> = client
186    ///         .send(
187    ///             cmd("MGET").arg("key1").arg("key2").arg("key3").arg("key4"),
188    ///             None,
189    ///         )
190    ///         .await?
191    ///         .to()?;
192    ///
193    ///     assert_eq!(vec!["value1".to_owned(), "value2".to_owned(), "value3".to_owned(), "value4".to_owned()], values);
194    ///
195    ///     Ok(())
196    /// }
197    /// ```
198
199    #[inline]
200    pub async fn send(&self, command: Command, retry_on_error: Option<bool>) -> Result<RespBuf> {
201        let (result_sender, result_receiver): (ResultSender, ResultReceiver) = oneshot::channel();
202        let message = Message::single(
203            command,
204            result_sender,
205            retry_on_error.unwrap_or(self.retry_on_error),
206        );
207        self.send_message(message)?;
208
209        if self.command_timeout != Duration::ZERO {
210            timeout(self.command_timeout, result_receiver).await??
211        } else {
212            result_receiver.await?
213        }
214    }
215
216    /// Send command to the Redis server and forget its response.
217    ///
218    /// # Arguments
219    /// * `command` - generic [`Command`](crate::resp::Command) meant to be sent to the Redis server.
220    /// * `retry_on_error` - retry to send the command on network error.
221    ///   * `None` - default behaviour defined in [`Config::retry_on_error`](crate::client::Config::retry_on_error)
222    ///   * `Some(true)` - retry sending command on network error
223    ///   * `Some(false)` - do not retry sending command on network error
224    ///
225    /// # Errors
226    /// Any Redis driver [`Error`](crate::Error) that occurs during the send operation
227    #[inline]
228    pub fn send_and_forget(&self, command: Command, retry_on_error: Option<bool>) -> Result<()> {
229        let message =
230            Message::single_forget(command, retry_on_error.unwrap_or(self.retry_on_error));
231        self.send_message(message)?;
232        Ok(())
233    }
234
235    /// Send a batch of commands to the Redis server.
236    ///
237    /// # Arguments
238    /// * `commands` - batch of generic [`Command`](crate::resp::Command)s meant to be sent to the Redis server.
239    /// * `retry_on_error` - retry to send the command batch on network error.
240    ///   * `None` - default behaviour defined in [`Config::retry_on_error`](crate::client::Config::retry_on_error)
241    ///   * `Some(true)` - retry sending batch on network error
242    ///   * `Some(false)` - do not retry sending batch on network error
243    ///
244    /// # Errors
245    /// Any Redis driver [`Error`](crate::Error) that occurs during the send operation
246    #[inline]
247    pub async fn send_batch(
248        &self,
249        commands: Vec<Command>,
250        retry_on_error: Option<bool>,
251    ) -> Result<Vec<RespBuf>> {
252        let (results_sender, results_receiver): (ResultsSender, ResultsReceiver) =
253            oneshot::channel();
254        let message = Message::batch(
255            commands,
256            results_sender,
257            retry_on_error.unwrap_or(self.retry_on_error),
258        );
259        self.send_message(message)?;
260
261        if self.command_timeout != Duration::ZERO {
262            timeout(self.command_timeout, results_receiver).await??
263        } else {
264            results_receiver.await?
265        }
266    }
267
268    #[inline]
269    fn send_message(&self, message: Message) -> Result<()> {
270        if let Some(msg_sender) = &self.msg_sender as &Option<MsgSender> {
271            trace!("Will enqueue message: {message:?}");
272            Ok(msg_sender.unbounded_send(message).map_err(|e| {
273                info!("{}", e.to_string());
274                Error::Client("Disconnected from server".to_string())
275            })?)
276        } else {
277            Err(Error::Client(
278                "Invalid channel to send messages to the network handler".to_owned(),
279            ))
280        }
281    }
282
283    /// Create a new transaction
284    #[inline]
285    pub fn create_transaction(&self) -> Transaction {
286        Transaction::new(self.clone())
287    }
288
289    /// Create a new pipeline
290    #[inline]
291    pub fn create_pipeline(&self) -> Pipeline {
292        Pipeline::new(self)
293    }
294
295    /// Create a new pub sub stream with no upfront subscription
296    #[inline]
297    pub fn create_pub_sub(&self) -> PubSubStream {
298        let (pub_sub_sender, pub_sub_receiver): (PubSubSender, PubSubReceiver) = mpsc::unbounded();
299        PubSubStream::new(pub_sub_sender, pub_sub_receiver, self.clone())
300    }
301
302    pub fn create_client_tracking_invalidation_stream(
303        &self,
304    ) -> Result<impl Stream<Item = Vec<String>>> {
305        let (push_sender, push_receiver): (PushSender, PushReceiver) = mpsc::unbounded();
306        let message = Message::client_tracking_invalidation(push_sender);
307        self.send_message(message)?;
308        Ok(ClientTrackingInvalidationStream::new(push_receiver))
309    }
310
311    pub(crate) async fn subscribe_from_pub_sub_sender(
312        &self,
313        channels: &CommandArgs,
314        pub_sub_sender: &PubSubSender,
315    ) -> Result<()> {
316        let (result_sender, result_receiver): (ResultSender, ResultReceiver) = oneshot::channel();
317
318        let pub_sub_senders = channels
319            .into_iter()
320            .map(|c| (c.to_vec(), pub_sub_sender.clone()))
321            .collect::<Vec<_>>();
322
323        let message = Message::pub_sub(
324            cmd("SUBSCRIBE").arg(channels.clone()),
325            result_sender,
326            pub_sub_senders,
327        );
328
329        self.send_message(message)?;
330
331        result_receiver.await??.to::<()>()
332    }
333
334    pub(crate) async fn psubscribe_from_pub_sub_sender(
335        &self,
336        patterns: &CommandArgs,
337        pub_sub_sender: &PubSubSender,
338    ) -> Result<()> {
339        let (result_sender, result_receiver): (ResultSender, ResultReceiver) = oneshot::channel();
340
341        let pub_sub_senders = patterns
342            .into_iter()
343            .map(|c| (c.to_vec(), pub_sub_sender.clone()))
344            .collect::<Vec<_>>();
345
346        let message = Message::pub_sub(
347            cmd("PSUBSCRIBE").arg(patterns.clone()),
348            result_sender,
349            pub_sub_senders,
350        );
351
352        self.send_message(message)?;
353
354        result_receiver.await??.to::<()>()
355    }
356
357    pub(crate) async fn ssubscribe_from_pub_sub_sender(
358        &self,
359        shardchannels: &CommandArgs,
360        pub_sub_sender: &PubSubSender,
361    ) -> Result<()> {
362        let (result_sender, result_receiver): (ResultSender, ResultReceiver) = oneshot::channel();
363
364        let pub_sub_senders = shardchannels
365            .into_iter()
366            .map(|c| (c.to_vec(), pub_sub_sender.clone()))
367            .collect::<Vec<_>>();
368
369        let message = Message::pub_sub(
370            cmd("SSUBSCRIBE").arg(shardchannels.clone()),
371            result_sender,
372            pub_sub_senders,
373        );
374
375        self.send_message(message)?;
376
377        result_receiver.await??.to::<()>()
378    }
379}
380
381/// Extension trait dedicated to [`PreparedCommand`](crate::client::PreparedCommand)
382/// to add specific methods for the [`Client`](crate::client::Client) executor
383pub trait ClientPreparedCommand<'a, R> {
384    /// Send command and forget its response
385    ///
386    /// # Errors
387    /// Any Redis driver [`Error`](crate::Error) that occur during the send operation
388    fn forget(self) -> Result<()>;
389}
390
391impl<'a, R: Response> ClientPreparedCommand<'a, R> for PreparedCommand<'a, &'a Client, R> {
392    /// Send command and forget its response
393    ///
394    /// # Errors
395    /// Any Redis driver [`Error`](crate::Error) that occur during the send operation
396    fn forget(self) -> Result<()> {
397        self.executor
398            .send_and_forget(self.command, self.retry_on_error)
399    }
400}
401
402impl<'a, R> IntoFuture for PreparedCommand<'a, &'a Client, R>
403where
404    R: DeserializeOwned + Send + 'a,
405{
406    type Output = Result<R>;
407    type IntoFuture = Future<'a, R>;
408
409    fn into_future(self) -> Self::IntoFuture {
410        Box::pin(async move {
411            if let Some(custom_converter) = self.custom_converter {
412                let command_for_result = self.command.clone();
413                let result = self
414                    .executor
415                    .send(self.command, self.retry_on_error)
416                    .await?;
417                custom_converter(result, command_for_result, self.executor).await
418            } else {
419                let result = self
420                    .executor
421                    .send(self.command, self.retry_on_error)
422                    .await?;
423                result.to()
424            }
425        })
426    }
427}
428
429impl<'a> BitmapCommands<'a> for &'a Client {}
430#[cfg_attr(docsrs, doc(cfg(feature = "redis-bloom")))]
431#[cfg(feature = "redis-bloom")]
432impl<'a> BloomCommands<'a> for &'a Client {}
433impl<'a> ClusterCommands<'a> for &'a Client {}
434#[cfg_attr(docsrs, doc(cfg(feature = "redis-bloom")))]
435#[cfg(feature = "redis-bloom")]
436impl<'a> CountMinSketchCommands<'a> for &'a Client {}
437#[cfg_attr(docsrs, doc(cfg(feature = "redis-bloom")))]
438#[cfg(feature = "redis-bloom")]
439impl<'a> CuckooCommands<'a> for &'a Client {}
440impl<'a> ConnectionCommands<'a> for &'a Client {}
441#[cfg(test)]
442impl<'a> DebugCommands<'a> for &'a Client {}
443impl<'a> GenericCommands<'a> for &'a Client {}
444impl<'a> GeoCommands<'a> for &'a Client {}
445#[cfg_attr(docsrs, doc(cfg(feature = "redis-graph")))]
446#[cfg(feature = "redis-graph")]
447impl<'a> GraphCommands<'a> for &'a Client {}
448impl<'a> HashCommands<'a> for &'a Client {}
449impl<'a> HyperLogLogCommands<'a> for &'a Client {}
450impl<'a> InternalPubSubCommands<'a> for &'a Client {}
451#[cfg_attr(docsrs, doc(cfg(feature = "redis-json")))]
452#[cfg(feature = "redis-json")]
453impl<'a> JsonCommands<'a> for &'a Client {}
454impl<'a> ListCommands<'a> for &'a Client {}
455impl<'a> ScriptingCommands<'a> for &'a Client {}
456#[cfg_attr(docsrs, doc(cfg(feature = "redis-search")))]
457#[cfg(feature = "redis-search")]
458impl<'a> SearchCommands<'a> for &'a Client {}
459impl<'a> SentinelCommands<'a> for &'a Client {}
460impl<'a> ServerCommands<'a> for &'a Client {}
461impl<'a> SetCommands<'a> for &'a Client {}
462impl<'a> SortedSetCommands<'a> for &'a Client {}
463impl<'a> StreamCommands<'a> for &'a Client {}
464impl<'a> StringCommands<'a> for &'a Client {}
465#[cfg_attr(docsrs, doc(cfg(feature = "redis-bloom")))]
466#[cfg(feature = "redis-bloom")]
467impl<'a> TDigestCommands<'a> for &'a Client {}
468#[cfg_attr(docsrs, doc(cfg(feature = "redis-time-series")))]
469#[cfg(feature = "redis-time-series")]
470impl<'a> TimeSeriesCommands<'a> for &'a Client {}
471impl<'a> TransactionCommands<'a> for &'a Client {}
472#[cfg_attr(docsrs, doc(cfg(feature = "redis-bloom")))]
473#[cfg(feature = "redis-bloom")]
474impl<'a> TopKCommands<'a> for &'a Client {}
475
476impl<'a> PubSubCommands<'a> for &'a Client {
477    #[inline]
478    async fn subscribe<C, CC>(self, channels: CC) -> Result<PubSubStream>
479    where
480        C: SingleArg + Send + 'a,
481        CC: SingleArgCollection<C>,
482    {
483        let channels = CommandArgs::default().arg(channels).build();
484
485        let (pub_sub_sender, pub_sub_receiver): (PubSubSender, PubSubReceiver) = mpsc::unbounded();
486
487        self.subscribe_from_pub_sub_sender(&channels, &pub_sub_sender)
488            .await?;
489
490        Ok(PubSubStream::from_channels(
491            channels,
492            pub_sub_sender,
493            pub_sub_receiver,
494            self.clone(),
495        ))
496    }
497
498    #[inline]
499    async fn psubscribe<P, PP>(self, patterns: PP) -> Result<PubSubStream>
500    where
501        P: SingleArg + Send + 'a,
502        PP: SingleArgCollection<P>,
503    {
504        let patterns = CommandArgs::default().arg(patterns).build();
505
506        let (pub_sub_sender, pub_sub_receiver): (PubSubSender, PubSubReceiver) = mpsc::unbounded();
507
508        self.psubscribe_from_pub_sub_sender(&patterns, &pub_sub_sender)
509            .await?;
510
511        Ok(PubSubStream::from_patterns(
512            patterns,
513            pub_sub_sender,
514            pub_sub_receiver,
515            self.clone(),
516        ))
517    }
518
519    #[inline]
520    async fn ssubscribe<C, CC>(self, shardchannels: CC) -> Result<PubSubStream>
521    where
522        C: SingleArg + Send + 'a,
523        CC: SingleArgCollection<C>,
524    {
525        let shardchannels = CommandArgs::default().arg(shardchannels).build();
526
527        let (pub_sub_sender, pub_sub_receiver): (PubSubSender, PubSubReceiver) = mpsc::unbounded();
528
529        self.ssubscribe_from_pub_sub_sender(&shardchannels, &pub_sub_sender)
530            .await?;
531
532        Ok(PubSubStream::from_shardchannels(
533            shardchannels,
534            pub_sub_sender,
535            pub_sub_receiver,
536            self.clone(),
537        ))
538    }
539}
540
541impl<'a> BlockingCommands<'a> for &'a Client {
542    async fn monitor(self) -> Result<MonitorStream> {
543        let (result_sender, result_receiver): (ResultSender, ResultReceiver) = oneshot::channel();
544        let (push_sender, push_receiver): (PushSender, PushReceiver) = mpsc::unbounded();
545
546        let message = Message::monitor(cmd("MONITOR"), result_sender, push_sender);
547
548        self.send_message(message)?;
549
550        let _bytes = result_receiver.await??;
551        Ok(MonitorStream::new(push_receiver, self.clone()))
552    }
553}