Skip to content

Commit dc8b4b8

Browse files
committed
Added from_coordinates method to SparseVector
1 parent 8b6102f commit dc8b4b8

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

src/sparsevec.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,23 @@ impl SparseVector {
4848
}
4949
}
5050

51+
/// Creates a sparse vector from coordinates.
52+
pub fn from_coordinates<I: IntoIterator<Item = (i32, f32)>>(
53+
iter: I,
54+
dim: usize,
55+
) -> SparseVector {
56+
let mut elements: Vec<(i32, f32)> = iter.into_iter().collect();
57+
elements.sort_by_key(|v| v.0);
58+
let indices: Vec<i32> = elements.iter().map(|v| v.0).collect();
59+
let values: Vec<f32> = elements.iter().map(|v| v.1).collect();
60+
61+
SparseVector {
62+
dim,
63+
indices,
64+
values,
65+
}
66+
}
67+
5168
/// Returns the sparse vector as a `Vec<f32>`.
5269
pub fn to_vec(&self) -> Vec<f32> {
5370
let mut vec = vec![0.0; self.dim];
@@ -91,13 +108,28 @@ impl SparseVector {
91108
#[cfg(test)]
92109
mod tests {
93110
use crate::SparseVector;
111+
use std::collections::HashMap;
94112

95113
#[test]
96114
fn test_from_dense() {
97115
let vec = SparseVector::from_dense(&[1.0, 0.0, 2.0, 0.0, 3.0, 0.0]);
98116
assert_eq!(vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0], vec.to_vec());
99117
}
100118

119+
#[test]
120+
fn test_from_coo_map() {
121+
let elements = HashMap::from([(0, 1.0), (2, 2.0), (4, 3.0)]);
122+
let vec = SparseVector::from_coordinates(elements, 6);
123+
assert_eq!(vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0], vec.to_vec());
124+
}
125+
126+
#[test]
127+
fn test_from_coo_vec() {
128+
let elements = vec![(0, 1.0), (2, 2.0), (4, 3.0)];
129+
let vec = SparseVector::from_coordinates(elements, 6);
130+
assert_eq!(vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0], vec.to_vec());
131+
}
132+
101133
#[test]
102134
fn test_to_vec() {
103135
let vec = SparseVector::new(6, vec![0, 2, 4], vec![1.0, 2.0, 3.0]);

0 commit comments

Comments
 (0)