1use serde::{
2 de::{self, DeserializeOwned, DeserializeSeed, IgnoredAny, SeqAccess, Visitor},
3 forward_to_deserialize_any, Deserializer,
4};
5
6#[cfg(feature = "redis-graph")]
7use crate::commands::GraphCommands;
8#[cfg(feature = "redis-json")]
9use crate::commands::JsonCommands;
10#[cfg(feature = "redis-search")]
11use crate::commands::SearchCommands;
12#[cfg(feature = "redis-time-series")]
13use crate::commands::TimeSeriesCommands;
14#[cfg(feature = "redis-bloom")]
15use crate::commands::{
16 BloomCommands, CountMinSketchCommands, CuckooCommands, TDigestCommands, TopKCommands,
17};
18use crate::{
19 client::{BatchPreparedCommand, Client, PreparedCommand},
20 commands::{
21 BitmapCommands, GenericCommands, GeoCommands, HashCommands, HyperLogLogCommands,
22 ListCommands, ScriptingCommands, ServerCommands, SetCommands, SortedSetCommands,
23 StreamCommands, StringCommands,
24 },
25 resp::{cmd, Command, RespDeserializer, Response},
26 Error, Result,
27};
28use std::{fmt, marker::PhantomData};
29
30pub struct Transaction {
32 client: Client,
33 commands: Vec<Command>,
34 forget_flags: Vec<bool>,
35 retry_on_error: Option<bool>,
36}
37
38impl Transaction {
39 pub(crate) fn new(client: Client) -> Self {
40 Self {
41 client,
42 commands: vec![cmd("MULTI")],
43 forget_flags: Vec::new(),
44 retry_on_error: None,
45 }
46 }
47
48 pub fn retry_on_error(&mut self, retry_on_error: bool) {
52 self.retry_on_error = Some(retry_on_error);
53 }
54
55 pub fn queue(&mut self, command: Command) {
57 self.commands.push(command);
58 self.forget_flags.push(false);
59 }
60
61 pub fn forget(&mut self, command: Command) {
63 self.commands.push(command);
64 self.forget_flags.push(true);
65 }
66
67 pub async fn execute<T: DeserializeOwned>(mut self) -> Result<T> {
103 self.commands.push(cmd("EXEC"));
104
105 let num_commands = self.commands.len();
106
107 let results = self
108 .client
109 .send_batch(self.commands, self.retry_on_error)
110 .await?;
111
112 let mut iter = results.into_iter();
113
114 for _ in 0..num_commands - 1 {
116 if let Some(resp_buf) = iter.next() {
117 resp_buf.to::<()>()?;
118 }
119 }
120
121 if let Some(result) = iter.next() {
123 let mut deserializer = RespDeserializer::new(&result);
124 match TransactionResultSeed::new(self.forget_flags).deserialize(&mut deserializer) {
125 Ok(Some(t)) => Ok(t),
126 Ok(None) => Err(Error::Aborted),
127 Err(e) => Err(e),
128 }
129 } else {
130 Err(Error::Client(
131 "Unexpected result for transaction".to_owned(),
132 ))
133 }
134 }
135}
136
/// Seed that turns the reply to `EXEC` into the caller's type `T`, skipping the
/// replies of forgotten commands.
///
/// The struct itself carries no trait bound on `T`; the `T: DeserializeOwned`
/// requirement lives on the `impl` blocks that actually need it.
struct TransactionResultSeed<T> {
    /// Records the target type without storing a value of it.
    phantom: PhantomData<T>,
    /// One flag per command added after `MULTI`; `true` means its reply is skipped.
    forget_flags: Vec<bool>,
}
141
142impl<T: DeserializeOwned> TransactionResultSeed<T> {
143 pub fn new(forget_flags: Vec<bool>) -> Self {
144 Self {
145 phantom: PhantomData,
146 forget_flags,
147 }
148 }
149}
150
151impl<'de, T: DeserializeOwned> DeserializeSeed<'de> for TransactionResultSeed<T> {
152 type Value = Option<T>;
153
154 fn deserialize<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
155 where
156 D: serde::Deserializer<'de>,
157 {
158 deserializer.deserialize_any(self)
159 }
160}
161
162impl<'de, T: DeserializeOwned> Visitor<'de> for TransactionResultSeed<T> {
163 type Value = Option<T>;
164
165 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
166 formatter.write_str("Option<T>")
167 }
168
169 fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
170 where
171 A: serde::de::SeqAccess<'de>,
172 {
173 if self
174 .forget_flags
175 .iter()
176 .fold(0, |acc, flag| if *flag { acc } else { acc + 1 })
177 == 1
178 {
179 for forget in &self.forget_flags {
180 if *forget {
181 seq.next_element::<IgnoredAny>()?;
182 } else {
183 return seq.next_element::<T>();
184 }
185 }
186 Ok(None)
187 } else {
188 let deserializer = SeqAccessDeserializer {
189 forget_flags: self.forget_flags.into_iter(),
190 seq_access: seq,
191 };
192
193 T::deserialize(deserializer)
194 .map(Some)
195 .map_err(de::Error::custom)
196 }
197 }
198
199 fn visit_none<E>(self) -> std::result::Result<Self::Value, E>
200 where
201 E: serde::de::Error,
202 {
203 Ok(None)
204 }
205}
206
/// Adapter that presents an underlying sequence of replies as a deserializer,
/// hiding the replies of forgotten commands.
struct SeqAccessDeserializer<A> {
    /// Remaining forget flags, one per remaining reply in `seq_access`
    /// (`true` = skip that reply).
    forget_flags: std::vec::IntoIter<bool>,
    /// The wrapped sequence of replies.
    seq_access: A,
}
211
212impl<'de, A> Deserializer<'de> for SeqAccessDeserializer<A>
213where
214 A: serde::de::SeqAccess<'de>,
215{
216 type Error = Error;
217
218 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
219 where
220 V: Visitor<'de>,
221 {
222 self.deserialize_seq(visitor)
223 }
224
225 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
226 where
227 V: Visitor<'de>,
228 {
229 visitor.visit_seq(self)
230 }
231
232 forward_to_deserialize_any! {
233 bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str
234 bytes byte_buf unit_struct newtype_struct string tuple
235 tuple_struct map struct enum identifier ignored_any unit option
236 }
237}
238
239impl<'de, A> SeqAccess<'de> for SeqAccessDeserializer<A>
240where
241 A: serde::de::SeqAccess<'de>,
242{
243 type Error = Error;
244
245 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
246 where
247 T: DeserializeSeed<'de>,
248 {
249 for forget in self.forget_flags.by_ref() {
250 if forget {
251 self.seq_access
252 .next_element::<IgnoredAny>()
253 .map_err::<Error, _>(de::Error::custom)?;
254 } else {
255 return self
256 .seq_access
257 .next_element_seed(seed)
258 .map_err(de::Error::custom);
259 }
260 }
261 Ok(None)
262 }
263}
264
265impl<'a, R: Response> BatchPreparedCommand for PreparedCommand<'a, &'a mut Transaction, R> {
266 fn queue(self) {
268 self.executor.queue(self.command)
269 }
270
271 fn forget(self) {
273 self.executor.forget(self.command)
274 }
275}
276
277impl<'a> BitmapCommands<'a> for &'a mut Transaction {}
278#[cfg_attr(docsrs, doc(cfg(feature = "redis-bloom")))]
279#[cfg(feature = "redis-bloom")]
280impl<'a> BloomCommands<'a> for &'a mut Transaction {}
281#[cfg_attr(docsrs, doc(cfg(feature = "redis-bloom")))]
282#[cfg(feature = "redis-bloom")]
283impl<'a> CountMinSketchCommands<'a> for &'a mut Transaction {}
284#[cfg_attr(docsrs, doc(cfg(feature = "redis-bloom")))]
285#[cfg(feature = "redis-bloom")]
286impl<'a> CuckooCommands<'a> for &'a mut Transaction {}
287impl<'a> GenericCommands<'a> for &'a mut Transaction {}
288impl<'a> GeoCommands<'a> for &'a mut Transaction {}
289#[cfg_attr(docsrs, doc(cfg(feature = "redis-graph")))]
290#[cfg(feature = "redis-graph")]
291impl<'a> GraphCommands<'a> for &'a mut Transaction {}
292impl<'a> HashCommands<'a> for &'a mut Transaction {}
293impl<'a> HyperLogLogCommands<'a> for &'a mut Transaction {}
294#[cfg_attr(docsrs, doc(cfg(feature = "redis-json")))]
295#[cfg(feature = "redis-json")]
296impl<'a> JsonCommands<'a> for &'a mut Transaction {}
297impl<'a> ListCommands<'a> for &'a mut Transaction {}
298#[cfg_attr(docsrs, doc(cfg(feature = "redis-search")))]
299#[cfg(feature = "redis-search")]
300impl<'a> SearchCommands<'a> for &'a mut Transaction {}
301impl<'a> SetCommands<'a> for &'a mut Transaction {}
302impl<'a> ScriptingCommands<'a> for &'a mut Transaction {}
303impl<'a> ServerCommands<'a> for &'a mut Transaction {}
304impl<'a> SortedSetCommands<'a> for &'a mut Transaction {}
305impl<'a> StreamCommands<'a> for &'a mut Transaction {}
306impl<'a> StringCommands<'a> for &'a mut Transaction {}
307#[cfg_attr(docsrs, doc(cfg(feature = "redis-bloom")))]
308#[cfg(feature = "redis-bloom")]
309impl<'a> TDigestCommands<'a> for &'a mut Transaction {}
310#[cfg_attr(docsrs, doc(cfg(feature = "redis-time-series")))]
311#[cfg(feature = "redis-time-series")]
312impl<'a> TimeSeriesCommands<'a> for &'a mut Transaction {}
313#[cfg_attr(docsrs, doc(cfg(feature = "redis-bloom")))]
314#[cfg(feature = "redis-bloom")]
315impl<'a> TopKCommands<'a> for &'a mut Transaction {}