ndarray/
array_serde.rs

1// Copyright 2014-2016 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8use serde::de::{self, MapAccess, SeqAccess, Visitor};
9use serde::ser::{SerializeSeq, SerializeStruct};
10use serde::{Deserialize, Deserializer, Serialize, Serializer};
11
12use alloc::format;
13#[cfg(not(feature = "std"))]
14use alloc::vec::Vec;
15use std::fmt;
16use std::marker::PhantomData;
17
18use crate::imp_prelude::*;
19
20use super::arraytraits::ARRAY_FORMAT_VERSION;
21use super::Iter;
22use crate::IntoDimension;
23
24/// Verifies that the version of the deserialized array matches the current
25/// `ARRAY_FORMAT_VERSION`.
26pub fn verify_version<E>(v: u8) -> Result<(), E>
27where E: de::Error
28{
29    if v != ARRAY_FORMAT_VERSION {
30        let err_msg = format!("unknown array version: {}", v);
31        Err(de::Error::custom(err_msg))
32    } else {
33        Ok(())
34    }
35}
36
37/// **Requires crate feature `"serde"`**
38impl<I> Serialize for Dim<I>
39where I: Serialize
40{
41    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
42    where Se: Serializer
43    {
44        self.ix().serialize(serializer)
45    }
46}
47
48/// **Requires crate feature `"serde"`**
49impl<'de, I> Deserialize<'de> for Dim<I>
50where I: Deserialize<'de>
51{
52    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
53    where D: Deserializer<'de>
54    {
55        I::deserialize(deserializer).map(Dim::new)
56    }
57}
58
59/// **Requires crate feature `"serde"`**
60impl Serialize for IxDyn
61{
62    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
63    where Se: Serializer
64    {
65        self.ix().serialize(serializer)
66    }
67}
68
69/// **Requires crate feature `"serde"`**
70impl<'de> Deserialize<'de> for IxDyn
71{
72    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
73    where D: Deserializer<'de>
74    {
75        let v = Vec::<Ix>::deserialize(deserializer)?;
76        Ok(v.into_dimension())
77    }
78}
79
80/// **Requires crate feature `"serde"`**
81impl<A, D, S> Serialize for ArrayBase<S, D>
82where
83    A: Serialize,
84    D: Dimension + Serialize,
85    S: Data<Elem = A>,
86{
87    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
88    where Se: Serializer
89    {
90        let mut state = serializer.serialize_struct("Array", 3)?;
91        state.serialize_field("v", &ARRAY_FORMAT_VERSION)?;
92        state.serialize_field("dim", &self.raw_dim())?;
93        state.serialize_field("data", &Sequence(self.iter()))?;
94        state.end()
95    }
96}
97
// Private wrapper around an array element iterator. Its `Serialize` impl
// writes the iterated elements out as a single serde sequence.
struct Sequence<'a, A, D>(Iter<'a, A, D>);
100
101impl<'a, A, D> Serialize for Sequence<'a, A, D>
102where
103    A: Serialize,
104    D: Dimension + Serialize,
105{
106    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
107    where S: Serializer
108    {
109        let iter = &self.0;
110        let mut seq = serializer.serialize_seq(Some(iter.len()))?;
111        for elt in iter.clone() {
112            seq.serialize_element(elt)?;
113        }
114        seq.end()
115    }
116}
117
/// Serde visitor that produces an `ArrayBase<S, Di>`.
///
/// It holds no data: the two `PhantomData` fields only carry the storage
/// type `S` and the dimension type `Di`.
struct ArrayVisitor<S, Di>
{
    _marker_a: PhantomData<S>,
    _marker_b: PhantomData<Di>,
}
123
/// Identifies one field of the serialized array representation.
enum ArrayField
{
    /// The `"v"` field: the format version.
    Version,
    /// The `"dim"` field: the array dimension.
    Dim,
    /// The `"data"` field: the array elements.
    Data,
}
130
131impl<S, Di> ArrayVisitor<S, Di>
132{
133    pub fn new() -> Self
134    {
135        ArrayVisitor {
136            _marker_a: PhantomData,
137            _marker_b: PhantomData,
138        }
139    }
140}
141
/// Field names of the serialized array representation, in the order the
/// `Serialize` impl writes them. Also passed to `deserialize_struct` and used
/// in "unknown field" errors.
static ARRAY_FIELDS: &[&str] = &["v", "dim", "data"];
143
144/// **Requires crate feature `"serde"`**
145impl<'de, A, Di, S> Deserialize<'de> for ArrayBase<S, Di>
146where
147    A: Deserialize<'de>,
148    Di: Deserialize<'de> + Dimension,
149    S: DataOwned<Elem = A>,
150{
151    fn deserialize<D>(deserializer: D) -> Result<ArrayBase<S, Di>, D::Error>
152    where D: Deserializer<'de>
153    {
154        deserializer.deserialize_struct("Array", ARRAY_FIELDS, ArrayVisitor::new())
155    }
156}
157
158impl<'de> Deserialize<'de> for ArrayField
159{
160    fn deserialize<D>(deserializer: D) -> Result<ArrayField, D::Error>
161    where D: Deserializer<'de>
162    {
163        struct ArrayFieldVisitor;
164
165        impl<'de> Visitor<'de> for ArrayFieldVisitor
166        {
167            type Value = ArrayField;
168
169            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result
170            {
171                formatter.write_str(r#""v", "dim", or "data""#)
172            }
173
174            fn visit_str<E>(self, value: &str) -> Result<ArrayField, E>
175            where E: de::Error
176            {
177                match value {
178                    "v" => Ok(ArrayField::Version),
179                    "dim" => Ok(ArrayField::Dim),
180                    "data" => Ok(ArrayField::Data),
181                    other => Err(de::Error::unknown_field(other, ARRAY_FIELDS)),
182                }
183            }
184
185            fn visit_bytes<E>(self, value: &[u8]) -> Result<ArrayField, E>
186            where E: de::Error
187            {
188                match value {
189                    b"v" => Ok(ArrayField::Version),
190                    b"dim" => Ok(ArrayField::Dim),
191                    b"data" => Ok(ArrayField::Data),
192                    other => Err(de::Error::unknown_field(&format!("{:?}", other), ARRAY_FIELDS)),
193                }
194            }
195        }
196
197        deserializer.deserialize_identifier(ArrayFieldVisitor)
198    }
199}
200
201impl<'de, A, Di, S> Visitor<'de> for ArrayVisitor<S, Di>
202where
203    A: Deserialize<'de>,
204    Di: Deserialize<'de> + Dimension,
205    S: DataOwned<Elem = A>,
206{
207    type Value = ArrayBase<S, Di>;
208
209    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result
210    {
211        formatter.write_str("ndarray representation")
212    }
213
214    fn visit_seq<V>(self, mut visitor: V) -> Result<ArrayBase<S, Di>, V::Error>
215    where V: SeqAccess<'de>
216    {
217        let v: u8 = match visitor.next_element()? {
218            Some(value) => value,
219            None => {
220                return Err(de::Error::invalid_length(0, &self));
221            }
222        };
223
224        verify_version(v)?;
225
226        let dim: Di = match visitor.next_element()? {
227            Some(value) => value,
228            None => {
229                return Err(de::Error::invalid_length(1, &self));
230            }
231        };
232
233        let data: Vec<A> = match visitor.next_element()? {
234            Some(value) => value,
235            None => {
236                return Err(de::Error::invalid_length(2, &self));
237            }
238        };
239
240        if let Ok(array) = ArrayBase::from_shape_vec(dim, data) {
241            Ok(array)
242        } else {
243            Err(de::Error::custom("data and dimension must match in size"))
244        }
245    }
246
247    fn visit_map<V>(self, mut visitor: V) -> Result<ArrayBase<S, Di>, V::Error>
248    where V: MapAccess<'de>
249    {
250        let mut v: Option<u8> = None;
251        let mut data: Option<Vec<A>> = None;
252        let mut dim: Option<Di> = None;
253
254        while let Some(key) = visitor.next_key()? {
255            match key {
256                ArrayField::Version => {
257                    let val = visitor.next_value()?;
258                    verify_version(val)?;
259                    v = Some(val);
260                }
261                ArrayField::Data => {
262                    data = Some(visitor.next_value()?);
263                }
264                ArrayField::Dim => {
265                    dim = Some(visitor.next_value()?);
266                }
267            }
268        }
269
270        let _v = match v {
271            Some(v) => v,
272            None => return Err(de::Error::missing_field("v")),
273        };
274
275        let data = match data {
276            Some(data) => data,
277            None => return Err(de::Error::missing_field("data")),
278        };
279
280        let dim = match dim {
281            Some(dim) => dim,
282            None => return Err(de::Error::missing_field("dim")),
283        };
284
285        if let Ok(array) = ArrayBase::from_shape_vec(dim, data) {
286            Ok(array)
287        } else {
288            Err(de::Error::custom("data and dimension must match in size"))
289        }
290    }
291}