// rustis/client/transaction.rs

1use serde::{
2    de::{self, DeserializeOwned, DeserializeSeed, IgnoredAny, SeqAccess, Visitor},
3    forward_to_deserialize_any, Deserializer,
4};
5
6#[cfg(feature = "redis-graph")]
7use crate::commands::GraphCommands;
8#[cfg(feature = "redis-json")]
9use crate::commands::JsonCommands;
10#[cfg(feature = "redis-search")]
11use crate::commands::SearchCommands;
12#[cfg(feature = "redis-time-series")]
13use crate::commands::TimeSeriesCommands;
14#[cfg(feature = "redis-bloom")]
15use crate::commands::{
16    BloomCommands, CountMinSketchCommands, CuckooCommands, TDigestCommands, TopKCommands,
17};
18use crate::{
19    client::{BatchPreparedCommand, Client, PreparedCommand},
20    commands::{
21        BitmapCommands, GenericCommands, GeoCommands, HashCommands, HyperLogLogCommands,
22        ListCommands, ScriptingCommands, ServerCommands, SetCommands, SortedSetCommands,
23        StreamCommands, StringCommands,
24    },
25    resp::{cmd, Command, RespDeserializer, Response},
26    Error, Result,
27};
28use std::{fmt, marker::PhantomData};
29
30/// Represents an on-going [`transaction`](https://redis.io/docs/manual/transactions/) on a specific client instance.
31pub struct Transaction {
32    client: Client,
33    commands: Vec<Command>,
34    forget_flags: Vec<bool>,
35    retry_on_error: Option<bool>,
36}
37
38impl Transaction {
39    pub(crate) fn new(client: Client) -> Self {
40        Self {
41            client,
42            commands: vec![cmd("MULTI")],
43            forget_flags: Vec::new(),
44            retry_on_error: None,
45        }
46    }
47
48    /// Set a flag to override default `retry_on_error` behavior.
49    ///
50    /// See [Config::retry_on_error](crate::client::Config::retry_on_error)
51    pub fn retry_on_error(&mut self, retry_on_error: bool) {
52        self.retry_on_error = Some(retry_on_error);
53    }
54
55    /// Queue a command into the transaction.
56    pub fn queue(&mut self, command: Command) {
57        self.commands.push(command);
58        self.forget_flags.push(false);
59    }
60
61    /// Queue a command into the transaction and forget its response.
62    pub fn forget(&mut self, command: Command) {
63        self.commands.push(command);
64        self.forget_flags.push(true);
65    }
66
67    /// Execute the transaction by the sending the queued command
68    /// as a whole batch to the Redis server.
69    ///
70    /// # Return
71    /// It is the caller responsability to use the right type to cast the server response
72    /// to the right tuple or collection depending on which command has been
73    /// [queued](BatchPreparedCommand::queue) or [forgotten](BatchPreparedCommand::forget).
74    ///
75    /// The most generic type that can be requested as a result is `Vec<resp::Value>`
76    ///
77    /// # Example
78    /// ```
79    /// use rustis::{
80    ///     client::{Client, Transaction, BatchPreparedCommand},
81    ///     commands::StringCommands,
82    ///     resp::{cmd, Value}, Result,
83    /// };
84    ///
85    /// #[cfg_attr(feature = "tokio-runtime", tokio::main)]
86    /// #[cfg_attr(feature = "async-std-runtime", async_std::main)]
87    /// async fn main() -> Result<()> {
88    ///     let client = Client::connect("127.0.0.1:6379").await?;
89    ///
90    ///     let mut transaction = client.create_transaction();
91    ///
92    ///     transaction.set("key1", "value1").forget();
93    ///     transaction.set("key2", "value2").forget();
94    ///     transaction.get::<_, String>("key1").queue();
95    ///     let value: String = transaction.execute().await?;
96    ///
97    ///     assert_eq!("value1", value);
98    ///
99    ///     Ok(())
100    /// }
101    /// ```
102    pub async fn execute<T: DeserializeOwned>(mut self) -> Result<T> {
103        self.commands.push(cmd("EXEC"));
104
105        let num_commands = self.commands.len();
106
107        let results = self
108            .client
109            .send_batch(self.commands, self.retry_on_error)
110            .await?;
111
112        let mut iter = results.into_iter();
113
114        // MULTI + QUEUED commands
115        for _ in 0..num_commands - 1 {
116            if let Some(resp_buf) = iter.next() {
117                resp_buf.to::<()>()?;
118            }
119        }
120
121        // EXEC
122        if let Some(result) = iter.next() {
123            let mut deserializer = RespDeserializer::new(&result);
124            match TransactionResultSeed::new(self.forget_flags).deserialize(&mut deserializer) {
125                Ok(Some(t)) => Ok(t),
126                Ok(None) => Err(Error::Aborted),
127                Err(e) => Err(e),
128            }
129        } else {
130            Err(Error::Client(
131                "Unexpected result for transaction".to_owned(),
132            ))
133        }
134    }
135}
136
/// Deserialization seed for the `EXEC` reply of a transaction.
///
/// Yields `Some(T)` built from the reply array, or `None` when the reply is nil.
/// `forget_flags` holds one entry per queued command, in queue order; `true`
/// marks a reply that is skipped instead of being handed to `T`.
///
/// The `T: DeserializeOwned` bound lives on the impls that need it, not on the type.
struct TransactionResultSeed<T> {
    phantom: PhantomData<T>,
    forget_flags: Vec<bool>,
}

impl<T> TransactionResultSeed<T> {
    /// Builds a seed from the per-command forget flags, in queue order.
    pub fn new(forget_flags: Vec<bool>) -> Self {
        Self {
            phantom: PhantomData,
            forget_flags,
        }
    }
}
150
151impl<'de, T: DeserializeOwned> DeserializeSeed<'de> for TransactionResultSeed<T> {
152    type Value = Option<T>;
153
154    fn deserialize<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
155    where
156        D: serde::Deserializer<'de>,
157    {
158        deserializer.deserialize_any(self)
159    }
160}
161
162impl<'de, T: DeserializeOwned> Visitor<'de> for TransactionResultSeed<T> {
163    type Value = Option<T>;
164
165    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
166        formatter.write_str("Option<T>")
167    }
168
169    fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
170    where
171        A: serde::de::SeqAccess<'de>,
172    {
173        if self
174            .forget_flags
175            .iter()
176            .fold(0, |acc, flag| if *flag { acc } else { acc + 1 })
177            == 1
178        {
179            for forget in &self.forget_flags {
180                if *forget {
181                    seq.next_element::<IgnoredAny>()?;
182                } else {
183                    return seq.next_element::<T>();
184                }
185            }
186            Ok(None)
187        } else {
188            let deserializer = SeqAccessDeserializer {
189                forget_flags: self.forget_flags.into_iter(),
190                seq_access: seq,
191            };
192
193            T::deserialize(deserializer)
194                .map(Some)
195                .map_err(de::Error::custom)
196        }
197    }
198
199    fn visit_none<E>(self) -> std::result::Result<Self::Value, E>
200    where
201        E: serde::de::Error,
202    {
203        Ok(None)
204    }
205}
206
/// Adapter that exposes only the non-forgotten replies of an `EXEC` array as a sequence.
///
/// Each forget flag corresponds, in order, to one element of `seq_access`.
struct SeqAccessDeserializer<A> {
    /// Flags not yet consumed; `true` means the matching element is skipped.
    forget_flags: std::vec::IntoIter<bool>,
    /// Underlying sequence of replies.
    seq_access: A,
}
211
212impl<'de, A> Deserializer<'de> for SeqAccessDeserializer<A>
213where
214    A: serde::de::SeqAccess<'de>,
215{
216    type Error = Error;
217
218    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
219    where
220        V: Visitor<'de>,
221    {
222        self.deserialize_seq(visitor)
223    }
224
225    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
226    where
227        V: Visitor<'de>,
228    {
229        visitor.visit_seq(self)
230    }
231
232    forward_to_deserialize_any! {
233        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str
234        bytes byte_buf unit_struct newtype_struct string tuple
235        tuple_struct map struct enum identifier ignored_any unit option
236    }
237}
238
239impl<'de, A> SeqAccess<'de> for SeqAccessDeserializer<A>
240where
241    A: serde::de::SeqAccess<'de>,
242{
243    type Error = Error;
244
245    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
246    where
247        T: DeserializeSeed<'de>,
248    {
249        for forget in self.forget_flags.by_ref() {
250            if forget {
251                self.seq_access
252                    .next_element::<IgnoredAny>()
253                    .map_err::<Error, _>(de::Error::custom)?;
254            } else {
255                return self
256                    .seq_access
257                    .next_element_seed(seed)
258                    .map_err(de::Error::custom);
259            }
260        }
261        Ok(None)
262    }
263}
264
265impl<'a, R: Response> BatchPreparedCommand for PreparedCommand<'a, &'a mut Transaction, R> {
266    /// Queue a command into the transaction.
267    fn queue(self) {
268        self.executor.queue(self.command)
269    }
270
271    /// Queue a command into the transaction and forget its response.
272    fn forget(self) {
273        self.executor.forget(self.command)
274    }
275}
276
277impl<'a> BitmapCommands<'a> for &'a mut Transaction {}
278#[cfg_attr(docsrs, doc(cfg(feature = "redis-bloom")))]
279#[cfg(feature = "redis-bloom")]
280impl<'a> BloomCommands<'a> for &'a mut Transaction {}
281#[cfg_attr(docsrs, doc(cfg(feature = "redis-bloom")))]
282#[cfg(feature = "redis-bloom")]
283impl<'a> CountMinSketchCommands<'a> for &'a mut Transaction {}
284#[cfg_attr(docsrs, doc(cfg(feature = "redis-bloom")))]
285#[cfg(feature = "redis-bloom")]
286impl<'a> CuckooCommands<'a> for &'a mut Transaction {}
287impl<'a> GenericCommands<'a> for &'a mut Transaction {}
288impl<'a> GeoCommands<'a> for &'a mut Transaction {}
289#[cfg_attr(docsrs, doc(cfg(feature = "redis-graph")))]
290#[cfg(feature = "redis-graph")]
291impl<'a> GraphCommands<'a> for &'a mut Transaction {}
292impl<'a> HashCommands<'a> for &'a mut Transaction {}
293impl<'a> HyperLogLogCommands<'a> for &'a mut Transaction {}
294#[cfg_attr(docsrs, doc(cfg(feature = "redis-json")))]
295#[cfg(feature = "redis-json")]
296impl<'a> JsonCommands<'a> for &'a mut Transaction {}
297impl<'a> ListCommands<'a> for &'a mut Transaction {}
298#[cfg_attr(docsrs, doc(cfg(feature = "redis-search")))]
299#[cfg(feature = "redis-search")]
300impl<'a> SearchCommands<'a> for &'a mut Transaction {}
301impl<'a> SetCommands<'a> for &'a mut Transaction {}
302impl<'a> ScriptingCommands<'a> for &'a mut Transaction {}
303impl<'a> ServerCommands<'a> for &'a mut Transaction {}
304impl<'a> SortedSetCommands<'a> for &'a mut Transaction {}
305impl<'a> StreamCommands<'a> for &'a mut Transaction {}
306impl<'a> StringCommands<'a> for &'a mut Transaction {}
307#[cfg_attr(docsrs, doc(cfg(feature = "redis-bloom")))]
308#[cfg(feature = "redis-bloom")]
309impl<'a> TDigestCommands<'a> for &'a mut Transaction {}
310#[cfg_attr(docsrs, doc(cfg(feature = "redis-time-series")))]
311#[cfg(feature = "redis-time-series")]
312impl<'a> TimeSeriesCommands<'a> for &'a mut Transaction {}
313#[cfg_attr(docsrs, doc(cfg(feature = "redis-bloom")))]
314#[cfg(feature = "redis-bloom")]
315impl<'a> TopKCommands<'a> for &'a mut Transaction {}