1#[cfg(test)]
2use crate::commands::DebugCommands;
3#[cfg(feature = "redis-graph")]
4use crate::commands::GraphCommands;
5#[cfg(feature = "redis-json")]
6use crate::commands::JsonCommands;
7#[cfg(feature = "redis-search")]
8use crate::commands::SearchCommands;
9#[cfg(feature = "redis-time-series")]
10use crate::commands::TimeSeriesCommands;
11#[cfg(feature = "redis-bloom")]
12use crate::commands::{
13 BloomCommands, CountMinSketchCommands, CuckooCommands, TDigestCommands, TopKCommands,
14};
15use crate::{
16 client::{
17 ClientState, ClientTrackingInvalidationStream, IntoConfig, Message, MonitorStream,
18 Pipeline, PreparedCommand, PubSubStream, Transaction,
19 },
20 commands::{
21 BitmapCommands, BlockingCommands, ClusterCommands, ConnectionCommands, GenericCommands,
22 GeoCommands, HashCommands, HyperLogLogCommands, InternalPubSubCommands, ListCommands,
23 PubSubCommands, ScriptingCommands, SentinelCommands, ServerCommands, SetCommands,
24 SortedSetCommands, StreamCommands, StringCommands, TransactionCommands,
25 },
26 network::{
27 timeout, JoinHandle, MsgSender, NetworkHandler, PubSubReceiver, PubSubSender, PushReceiver,
28 PushSender, ReconnectReceiver, ReconnectSender, ResultReceiver, ResultSender,
29 ResultsReceiver, ResultsSender,
30 },
31 resp::{cmd, Command, CommandArgs, RespBuf, Response, SingleArg, SingleArgCollection},
32 Error, Future, Result,
33};
34use futures_channel::{mpsc, oneshot};
35use futures_util::Stream;
36use log::{info, trace};
37use serde::de::DeserializeOwned;
38use std::{
39 future::IntoFuture,
40 sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard},
41 time::Duration,
42};
43
/// Handle used to send commands to a Redis server through a background network task.
///
/// The handle is `Clone`: the message channel, the network task handle and the client
/// state are all held behind `Arc`s, so every clone refers to the same underlying values.
#[derive(Clone)]
pub struct Client {
    /// Sender side of the channel on which messages are enqueued for the network handler.
    /// Replaced by `None` when `close` or `drop` takes it out.
    msg_sender: Arc<Option<MsgSender>>,
    /// Join handle of the network task; `close` awaits it when this is the last clone.
    /// Replaced by `None` when `close` or `drop` takes it out.
    network_task_join_handle: Arc<Option<JoinHandle<()>>>,
    /// Sender from which `on_reconnect` creates new reconnection receivers.
    reconnect_sender: ReconnectSender,
    /// State shared by all clones, accessed through `get_client_state` and
    /// `get_client_state_mut`.
    client_state: Arc<RwLock<ClientState>>,
    /// Maximum time `send` and `send_batch` wait for a result;
    /// `Duration::ZERO` means they wait without any timeout.
    command_timeout: Duration,
    /// Retry flag applied when a caller passes `None` as `retry_on_error`.
    retry_on_error: bool,
}
54
55impl Drop for Client {
56 fn drop(&mut self) {
59 let mut network_task_join_handle: Arc<Option<JoinHandle<()>>> = Arc::new(None);
60 std::mem::swap(
61 &mut network_task_join_handle,
62 &mut self.network_task_join_handle,
63 );
64
65 if Arc::try_unwrap(network_task_join_handle).is_ok() {
67 let mut msg_sender: Arc<Option<MsgSender>> = Arc::new(None);
68 std::mem::swap(&mut msg_sender, &mut self.msg_sender);
69
70 if let Ok(Some(msg_sender)) = Arc::try_unwrap(msg_sender) {
71 msg_sender.close_channel();
73 }
74 };
75 }
76}
77
78impl Client {
79 #[inline]
84 pub async fn connect(config: impl IntoConfig) -> Result<Self> {
85 let config = config.into_config()?;
86 let command_timeout = config.command_timeout;
87 let retry_on_error = config.retry_on_error;
88 let (msg_sender, network_task_join_handle, reconnect_sender) =
89 NetworkHandler::connect(config.into_config()?).await?;
90
91 Ok(Self {
92 msg_sender: Arc::new(Some(msg_sender)),
93 network_task_join_handle: Arc::new(Some(network_task_join_handle)),
94 reconnect_sender,
95 client_state: Arc::new(RwLock::new(ClientState::new())),
96 command_timeout,
97 retry_on_error,
98 })
99 }
100
101 pub async fn close(mut self) -> Result<()> {
106 let mut network_task_join_handle: Arc<Option<JoinHandle<()>>> = Arc::new(None);
107 std::mem::swap(
108 &mut network_task_join_handle,
109 &mut self.network_task_join_handle,
110 );
111
112 if let Ok(Some(network_task_join_handle)) = Arc::try_unwrap(network_task_join_handle) {
114 let mut msg_sender: Arc<Option<MsgSender>> = Arc::new(None);
115 std::mem::swap(&mut msg_sender, &mut self.msg_sender);
116
117 if let Ok(Some(msg_sender)) = Arc::try_unwrap(msg_sender) {
118 msg_sender.close_channel();
120 network_task_join_handle.await?;
121 }
122 };
123
124 Ok(())
125 }
126
    /// Returns a new receiver subscribed to this client's reconnection sender.
    ///
    /// NOTE(review): presumably the network handler signals it after each reconnection
    /// to the server — confirm against `NetworkHandler`.
    pub fn on_reconnect(&self) -> ReconnectReceiver {
        self.reconnect_sender.subscribe()
    }
134
135 pub fn get_client_state(&self) -> RwLockReadGuard<ClientState> {
137 self.client_state.read().unwrap()
138 }
139
140 pub fn get_client_state_mut(&self) -> RwLockWriteGuard<ClientState> {
142 self.client_state.write().unwrap()
143 }
144
145 #[inline]
200 pub async fn send(&self, command: Command, retry_on_error: Option<bool>) -> Result<RespBuf> {
201 let (result_sender, result_receiver): (ResultSender, ResultReceiver) = oneshot::channel();
202 let message = Message::single(
203 command,
204 result_sender,
205 retry_on_error.unwrap_or(self.retry_on_error),
206 );
207 self.send_message(message)?;
208
209 if self.command_timeout != Duration::ZERO {
210 timeout(self.command_timeout, result_receiver).await??
211 } else {
212 result_receiver.await?
213 }
214 }
215
216 #[inline]
228 pub fn send_and_forget(&self, command: Command, retry_on_error: Option<bool>) -> Result<()> {
229 let message =
230 Message::single_forget(command, retry_on_error.unwrap_or(self.retry_on_error));
231 self.send_message(message)?;
232 Ok(())
233 }
234
235 #[inline]
247 pub async fn send_batch(
248 &self,
249 commands: Vec<Command>,
250 retry_on_error: Option<bool>,
251 ) -> Result<Vec<RespBuf>> {
252 let (results_sender, results_receiver): (ResultsSender, ResultsReceiver) =
253 oneshot::channel();
254 let message = Message::batch(
255 commands,
256 results_sender,
257 retry_on_error.unwrap_or(self.retry_on_error),
258 );
259 self.send_message(message)?;
260
261 if self.command_timeout != Duration::ZERO {
262 timeout(self.command_timeout, results_receiver).await??
263 } else {
264 results_receiver.await?
265 }
266 }
267
268 #[inline]
269 fn send_message(&self, message: Message) -> Result<()> {
270 if let Some(msg_sender) = &self.msg_sender as &Option<MsgSender> {
271 trace!("Will enqueue message: {message:?}");
272 Ok(msg_sender.unbounded_send(message).map_err(|e| {
273 info!("{}", e.to_string());
274 Error::Client("Disconnected from server".to_string())
275 })?)
276 } else {
277 Err(Error::Client(
278 "Invalid channel to send messages to the network handler".to_owned(),
279 ))
280 }
281 }
282
283 #[inline]
285 pub fn create_transaction(&self) -> Transaction {
286 Transaction::new(self.clone())
287 }
288
    /// Creates a new `Pipeline` built from a reference to this client.
    #[inline]
    pub fn create_pipeline(&self) -> Pipeline {
        Pipeline::new(self)
    }
294
295 #[inline]
297 pub fn create_pub_sub(&self) -> PubSubStream {
298 let (pub_sub_sender, pub_sub_receiver): (PubSubSender, PubSubReceiver) = mpsc::unbounded();
299 PubSubStream::new(pub_sub_sender, pub_sub_receiver, self.clone())
300 }
301
302 pub fn create_client_tracking_invalidation_stream(
303 &self,
304 ) -> Result<impl Stream<Item = Vec<String>>> {
305 let (push_sender, push_receiver): (PushSender, PushReceiver) = mpsc::unbounded();
306 let message = Message::client_tracking_invalidation(push_sender);
307 self.send_message(message)?;
308 Ok(ClientTrackingInvalidationStream::new(push_receiver))
309 }
310
311 pub(crate) async fn subscribe_from_pub_sub_sender(
312 &self,
313 channels: &CommandArgs,
314 pub_sub_sender: &PubSubSender,
315 ) -> Result<()> {
316 let (result_sender, result_receiver): (ResultSender, ResultReceiver) = oneshot::channel();
317
318 let pub_sub_senders = channels
319 .into_iter()
320 .map(|c| (c.to_vec(), pub_sub_sender.clone()))
321 .collect::<Vec<_>>();
322
323 let message = Message::pub_sub(
324 cmd("SUBSCRIBE").arg(channels.clone()),
325 result_sender,
326 pub_sub_senders,
327 );
328
329 self.send_message(message)?;
330
331 result_receiver.await??.to::<()>()
332 }
333
334 pub(crate) async fn psubscribe_from_pub_sub_sender(
335 &self,
336 patterns: &CommandArgs,
337 pub_sub_sender: &PubSubSender,
338 ) -> Result<()> {
339 let (result_sender, result_receiver): (ResultSender, ResultReceiver) = oneshot::channel();
340
341 let pub_sub_senders = patterns
342 .into_iter()
343 .map(|c| (c.to_vec(), pub_sub_sender.clone()))
344 .collect::<Vec<_>>();
345
346 let message = Message::pub_sub(
347 cmd("PSUBSCRIBE").arg(patterns.clone()),
348 result_sender,
349 pub_sub_senders,
350 );
351
352 self.send_message(message)?;
353
354 result_receiver.await??.to::<()>()
355 }
356
357 pub(crate) async fn ssubscribe_from_pub_sub_sender(
358 &self,
359 shardchannels: &CommandArgs,
360 pub_sub_sender: &PubSubSender,
361 ) -> Result<()> {
362 let (result_sender, result_receiver): (ResultSender, ResultReceiver) = oneshot::channel();
363
364 let pub_sub_senders = shardchannels
365 .into_iter()
366 .map(|c| (c.to_vec(), pub_sub_sender.clone()))
367 .collect::<Vec<_>>();
368
369 let message = Message::pub_sub(
370 cmd("SSUBSCRIBE").arg(shardchannels.clone()),
371 result_sender,
372 pub_sub_senders,
373 );
374
375 self.send_message(message)?;
376
377 result_receiver.await??.to::<()>()
378 }
379}
380
/// Extension trait for prepared commands that are executed by a client.
pub trait ClientPreparedCommand<'a, R> {
    /// Consumes the prepared command and sends it without waiting for its response.
    ///
    /// The implementation for `PreparedCommand<'a, &'a Client, R>` in this file delegates
    /// to `Client::send_and_forget`.
    ///
    /// # Errors
    /// Returns an error if the command cannot be sent.
    fn forget(self) -> Result<()>;
}
390
391impl<'a, R: Response> ClientPreparedCommand<'a, R> for PreparedCommand<'a, &'a Client, R> {
392 fn forget(self) -> Result<()> {
397 self.executor
398 .send_and_forget(self.command, self.retry_on_error)
399 }
400}
401
402impl<'a, R> IntoFuture for PreparedCommand<'a, &'a Client, R>
403where
404 R: DeserializeOwned + Send + 'a,
405{
406 type Output = Result<R>;
407 type IntoFuture = Future<'a, R>;
408
409 fn into_future(self) -> Self::IntoFuture {
410 Box::pin(async move {
411 if let Some(custom_converter) = self.custom_converter {
412 let command_for_result = self.command.clone();
413 let result = self
414 .executor
415 .send(self.command, self.retry_on_error)
416 .await?;
417 custom_converter(result, command_for_result, self.executor).await
418 } else {
419 let result = self
420 .executor
421 .send(self.command, self.retry_on_error)
422 .await?;
423 result.to()
424 }
425 })
426 }
427}
428
// Command trait implementations for `&Client`.
//
// Every impl body below is empty, so all command methods come from the traits' provided
// (default) implementations. Impls guarded by `#[cfg(feature = "...")]` only exist when
// the corresponding Cargo feature is enabled; the matching `cfg_attr(docsrs, ...)` line
// annotates them with that feature requirement when building with the `docsrs` cfg.
impl<'a> BitmapCommands<'a> for &'a Client {}
#[cfg_attr(docsrs, doc(cfg(feature = "redis-bloom")))]
#[cfg(feature = "redis-bloom")]
impl<'a> BloomCommands<'a> for &'a Client {}
impl<'a> ClusterCommands<'a> for &'a Client {}
#[cfg_attr(docsrs, doc(cfg(feature = "redis-bloom")))]
#[cfg(feature = "redis-bloom")]
impl<'a> CountMinSketchCommands<'a> for &'a Client {}
#[cfg_attr(docsrs, doc(cfg(feature = "redis-bloom")))]
#[cfg(feature = "redis-bloom")]
impl<'a> CuckooCommands<'a> for &'a Client {}
impl<'a> ConnectionCommands<'a> for &'a Client {}
// Only compiled for test builds.
#[cfg(test)]
impl<'a> DebugCommands<'a> for &'a Client {}
impl<'a> GenericCommands<'a> for &'a Client {}
impl<'a> GeoCommands<'a> for &'a Client {}
#[cfg_attr(docsrs, doc(cfg(feature = "redis-graph")))]
#[cfg(feature = "redis-graph")]
impl<'a> GraphCommands<'a> for &'a Client {}
impl<'a> HashCommands<'a> for &'a Client {}
impl<'a> HyperLogLogCommands<'a> for &'a Client {}
impl<'a> InternalPubSubCommands<'a> for &'a Client {}
#[cfg_attr(docsrs, doc(cfg(feature = "redis-json")))]
#[cfg(feature = "redis-json")]
impl<'a> JsonCommands<'a> for &'a Client {}
impl<'a> ListCommands<'a> for &'a Client {}
impl<'a> ScriptingCommands<'a> for &'a Client {}
#[cfg_attr(docsrs, doc(cfg(feature = "redis-search")))]
#[cfg(feature = "redis-search")]
impl<'a> SearchCommands<'a> for &'a Client {}
impl<'a> SentinelCommands<'a> for &'a Client {}
impl<'a> ServerCommands<'a> for &'a Client {}
impl<'a> SetCommands<'a> for &'a Client {}
impl<'a> SortedSetCommands<'a> for &'a Client {}
impl<'a> StreamCommands<'a> for &'a Client {}
impl<'a> StringCommands<'a> for &'a Client {}
#[cfg_attr(docsrs, doc(cfg(feature = "redis-bloom")))]
#[cfg(feature = "redis-bloom")]
impl<'a> TDigestCommands<'a> for &'a Client {}
#[cfg_attr(docsrs, doc(cfg(feature = "redis-time-series")))]
#[cfg(feature = "redis-time-series")]
impl<'a> TimeSeriesCommands<'a> for &'a Client {}
impl<'a> TransactionCommands<'a> for &'a Client {}
#[cfg_attr(docsrs, doc(cfg(feature = "redis-bloom")))]
#[cfg(feature = "redis-bloom")]
impl<'a> TopKCommands<'a> for &'a Client {}
475
476impl<'a> PubSubCommands<'a> for &'a Client {
477 #[inline]
478 async fn subscribe<C, CC>(self, channels: CC) -> Result<PubSubStream>
479 where
480 C: SingleArg + Send + 'a,
481 CC: SingleArgCollection<C>,
482 {
483 let channels = CommandArgs::default().arg(channels).build();
484
485 let (pub_sub_sender, pub_sub_receiver): (PubSubSender, PubSubReceiver) = mpsc::unbounded();
486
487 self.subscribe_from_pub_sub_sender(&channels, &pub_sub_sender)
488 .await?;
489
490 Ok(PubSubStream::from_channels(
491 channels,
492 pub_sub_sender,
493 pub_sub_receiver,
494 self.clone(),
495 ))
496 }
497
498 #[inline]
499 async fn psubscribe<P, PP>(self, patterns: PP) -> Result<PubSubStream>
500 where
501 P: SingleArg + Send + 'a,
502 PP: SingleArgCollection<P>,
503 {
504 let patterns = CommandArgs::default().arg(patterns).build();
505
506 let (pub_sub_sender, pub_sub_receiver): (PubSubSender, PubSubReceiver) = mpsc::unbounded();
507
508 self.psubscribe_from_pub_sub_sender(&patterns, &pub_sub_sender)
509 .await?;
510
511 Ok(PubSubStream::from_patterns(
512 patterns,
513 pub_sub_sender,
514 pub_sub_receiver,
515 self.clone(),
516 ))
517 }
518
519 #[inline]
520 async fn ssubscribe<C, CC>(self, shardchannels: CC) -> Result<PubSubStream>
521 where
522 C: SingleArg + Send + 'a,
523 CC: SingleArgCollection<C>,
524 {
525 let shardchannels = CommandArgs::default().arg(shardchannels).build();
526
527 let (pub_sub_sender, pub_sub_receiver): (PubSubSender, PubSubReceiver) = mpsc::unbounded();
528
529 self.ssubscribe_from_pub_sub_sender(&shardchannels, &pub_sub_sender)
530 .await?;
531
532 Ok(PubSubStream::from_shardchannels(
533 shardchannels,
534 pub_sub_sender,
535 pub_sub_receiver,
536 self.clone(),
537 ))
538 }
539}
540
541impl<'a> BlockingCommands<'a> for &'a Client {
542 async fn monitor(self) -> Result<MonitorStream> {
543 let (result_sender, result_receiver): (ResultSender, ResultReceiver) = oneshot::channel();
544 let (push_sender, push_receiver): (PushSender, PushReceiver) = mpsc::unbounded();
545
546 let message = Message::monitor(cmd("MONITOR"), result_sender, push_sender);
547
548 self.send_message(message)?;
549
550 let _bytes = result_receiver.await??;
551 Ok(MonitorStream::new(push_receiver, self.clone()))
552 }
553}