Skip to content

Commit c8ce3bd

Browse files
authored
Merge pull request #1187 from RustPython/dict-keying
Implement dictionary indexing by trait.
2 parents 50dd93d + 42dca44 commit c8ce3bd

File tree

6 files changed

+127
-43
lines changed

6 files changed

+127
-43
lines changed

vm/src/dictdatatype.rs

Lines changed: 98 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
use crate::obj::objbool;
2+
use crate::obj::objstr::PyString;
23
use crate::pyhash;
34
use crate::pyobject::{IdProtocol, PyObjectRef, PyResult};
45
use crate::vm::VirtualMachine;
6+
use num_bigint::ToBigInt;
57
/// Ordered dictionary implementation.
68
/// Inspired by: https://morepypy.blogspot.com/2015/01/faster-more-memory-efficient-and-more.html
79
/// And: https://www.youtube.com/watch?v=p33CVV29OG8
@@ -93,7 +95,7 @@ impl<T: Clone> Dict<T> {
9395
}
9496
}
9597

96-
pub fn contains(&self, vm: &VirtualMachine, key: &PyObjectRef) -> PyResult<bool> {
98+
pub fn contains<K: DictKey>(&self, vm: &VirtualMachine, key: &K) -> PyResult<bool> {
9799
if let LookupResult::Existing(_) = self.lookup(vm, key)? {
98100
Ok(true)
99101
} else {
@@ -111,7 +113,7 @@ impl<T: Clone> Dict<T> {
111113

112114
/// Retrieve a key
113115
#[cfg_attr(feature = "flame-it", flame("Dict"))]
114-
pub fn get(&self, vm: &VirtualMachine, key: &PyObjectRef) -> PyResult<Option<T>> {
116+
pub fn get<K: DictKey>(&self, vm: &VirtualMachine, key: &K) -> PyResult<Option<T>> {
115117
if let LookupResult::Existing(index) = self.lookup(vm, key)? {
116118
Ok(Some(self.unchecked_get(index)))
117119
} else {
@@ -149,7 +151,7 @@ impl<T: Clone> Dict<T> {
149151
key: &PyObjectRef,
150152
value: T,
151153
) -> PyResult<()> {
152-
match self.lookup(vm, &key)? {
154+
match self.lookup(vm, key)? {
153155
LookupResult::Existing(entry_index) => self.unchecked_delete(entry_index),
154156
LookupResult::NewIndex {
155157
hash_value,
@@ -199,8 +201,8 @@ impl<T: Clone> Dict<T> {
199201

200202
/// Lookup the index for the given key.
201203
#[cfg_attr(feature = "flame-it", flame("Dict"))]
202-
fn lookup(&self, vm: &VirtualMachine, key: &PyObjectRef) -> PyResult<LookupResult> {
203-
let hash_value = collection_hash(vm, key)?;
204+
fn lookup<K: DictKey>(&self, vm: &VirtualMachine, key: &K) -> PyResult<LookupResult> {
205+
let hash_value = key.do_hash(vm)?;
204206
let perturb = hash_value;
205207
let mut hash_index: HashIndex = hash_value;
206208
loop {
@@ -209,11 +211,11 @@ impl<T: Clone> Dict<T> {
209211
let index = self.indices[&hash_index];
210212
if let Some(entry) = &self.entries[index] {
211213
// Okay, we have an entry at this place
212-
if entry.key.is(key) {
214+
if key.do_is(&entry.key) {
213215
// Literally the same object
214216
break Ok(LookupResult::Existing(index));
215217
} else if entry.hash == hash_value {
216-
if do_eq(vm, &entry.key, key)? {
218+
if key.do_eq(vm, &entry.key)? {
217219
break Ok(LookupResult::Existing(index));
218220
} else {
219221
// entry mismatch.
@@ -242,7 +244,7 @@ impl<T: Clone> Dict<T> {
242244
}
243245

244246
/// Retrieve and delete a key
245-
pub fn pop(&mut self, vm: &VirtualMachine, key: &PyObjectRef) -> PyResult<Option<T>> {
247+
pub fn pop<K: DictKey>(&mut self, vm: &VirtualMachine, key: &K) -> PyResult<Option<T>> {
246248
if let LookupResult::Existing(index) = self.lookup(vm, key)? {
247249
let value = self.unchecked_get(index);
248250
self.unchecked_delete(index);
@@ -273,23 +275,68 @@ enum LookupResult {
273275
Existing(EntryIndex), // Existing record, index into entries
274276
}
275277

276-
#[cfg_attr(feature = "flame-it", flame())]
277-
fn collection_hash(vm: &VirtualMachine, object: &PyObjectRef) -> PyResult<HashValue> {
278-
let raw_hash = vm._hash(object)?;
279-
let mut hasher = DefaultHasher::new();
280-
raw_hash.hash(&mut hasher);
281-
Ok(hasher.finish() as HashValue)
278+
/// Types implementing this trait can be used to index
279+
/// the dictionary. Typical use cases are:
280+
/// - PyObjectRef -> arbitrary python type used as key
281+
/// - String -> owned Rust string used as key; this is often used internally
282+
pub trait DictKey {
283+
fn do_hash(&self, vm: &VirtualMachine) -> PyResult<HashValue>;
284+
fn do_is(&self, other: &PyObjectRef) -> bool;
285+
fn do_eq(&self, vm: &VirtualMachine, other_key: &PyObjectRef) -> PyResult<bool>;
282286
}
283287

284-
/// Invoke __eq__ on two keys
285-
fn do_eq(vm: &VirtualMachine, key1: &PyObjectRef, key2: &PyObjectRef) -> Result<bool, PyObjectRef> {
286-
let result = vm._eq(key1.clone(), key2.clone())?;
287-
objbool::boolval(vm, result)
288+
/// Implement trait for PyObjectRef such that we can use python objects
289+
/// to index dictionaries.
290+
impl DictKey for PyObjectRef {
291+
fn do_hash(&self, vm: &VirtualMachine) -> PyResult<HashValue> {
292+
let raw_hash = vm._hash(self)?;
293+
let mut hasher = DefaultHasher::new();
294+
raw_hash.hash(&mut hasher);
295+
Ok(hasher.finish() as HashValue)
296+
}
297+
298+
fn do_is(&self, other: &PyObjectRef) -> bool {
299+
self.is(other)
300+
}
301+
302+
fn do_eq(&self, vm: &VirtualMachine, other_key: &PyObjectRef) -> PyResult<bool> {
303+
let result = vm._eq(self.clone(), other_key.clone())?;
304+
objbool::boolval(vm, result)
305+
}
306+
}
307+
308+
/// Implement trait for the String type, so that we can use strings
309+
/// to index dictionaries.
310+
impl DictKey for String {
311+
fn do_hash(&self, _vm: &VirtualMachine) -> PyResult<HashValue> {
312+
// follow a similar route as the hashing of PyStringRef
313+
let raw_hash = pyhash::hash_value(self).to_bigint().unwrap();
314+
let raw_hash = pyhash::hash_bigint(&raw_hash);
315+
let mut hasher = DefaultHasher::new();
316+
raw_hash.hash(&mut hasher);
317+
Ok(hasher.finish() as HashValue)
318+
}
319+
320+
fn do_is(&self, _other: &PyObjectRef) -> bool {
321+
// No matter what the other pyobject is, we are never the same object, since
322+
// we are a Rust String, not a pyobject.
323+
false
324+
}
325+
326+
fn do_eq(&self, vm: &VirtualMachine, other_key: &PyObjectRef) -> PyResult<bool> {
327+
if let Some(py_str_value) = other_key.payload::<PyString>() {
328+
Ok(&py_str_value.value == self)
329+
} else {
330+
// Fall back to the PyObjectRef implementation by wrapping self in a Python str.
331+
let s = vm.new_str(self.to_string());
332+
s.do_eq(vm, other_key)
333+
}
334+
}
288335
}
289336

290337
#[cfg(test)]
291338
mod tests {
292-
use super::{Dict, VirtualMachine};
339+
use super::{Dict, DictKey, VirtualMachine};
293340

294341
#[test]
295342
fn test_insert() {
@@ -313,9 +360,40 @@ mod tests {
313360
dict.delete(&vm, &key1).unwrap();
314361
assert_eq!(1, dict.len());
315362

316-
dict.insert(&vm, &key1, value2).unwrap();
363+
dict.insert(&vm, &key1, value2.clone()).unwrap();
317364
assert_eq!(2, dict.len());
318365

319366
assert_eq!(true, dict.contains(&vm, &key1).unwrap());
367+
assert_eq!(true, dict.contains(&vm, &"x".to_string()).unwrap());
368+
369+
let val = dict.get(&vm, &"x".to_string()).unwrap().unwrap();
370+
vm._eq(val, value2)
371+
.expect("retrieved value must be equal to inserted value.");
372+
}
373+
374+
macro_rules! hash_tests {
375+
($($name:ident: $example_hash:expr,)*) => {
376+
$(
377+
#[test]
378+
fn $name() {
379+
check_hash_equivalence($example_hash);
380+
}
381+
)*
382+
}
383+
}
384+
385+
hash_tests! {
386+
test_abc: "abc",
387+
test_x: "x",
388+
}
389+
390+
fn check_hash_equivalence(text: &str) {
391+
let vm: VirtualMachine = Default::default();
392+
let value1 = text.to_string();
393+
let value2 = vm.new_str(value1.clone());
394+
395+
let hash1 = value1.do_hash(&vm).expect("Hash should not fail.");
396+
let hash2 = value2.do_hash(&vm).expect("Hash should not fail.");
397+
assert_eq!(hash1, hash2);
320398
}
321399
}

vm/src/obj/objdict.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use crate::vm::{ReprGuard, VirtualMachine};
1111
use super::objbool;
1212
use super::objiter;
1313
use super::objstr;
14+
use super::objtype;
1415
use crate::dictdatatype;
1516
use crate::obj::objtype::PyClassRef;
1617
use crate::pyobject::PyClassImpl;
@@ -316,6 +317,23 @@ impl PyDictRef {
316317
pub fn size(&self) -> dictdatatype::DictSize {
317318
self.entries.borrow().size()
318319
}
320+
321+
pub fn get_item_option<T: IntoPyObject>(
322+
&self,
323+
key: T,
324+
vm: &VirtualMachine,
325+
) -> PyResult<Option<PyObjectRef>> {
326+
match self.get_item(key, vm) {
327+
Ok(value) => Ok(Some(value)),
328+
Err(exc) => {
329+
if objtype::isinstance(&exc, &vm.ctx.exceptions.key_error) {
330+
Ok(None)
331+
} else {
332+
Err(exc)
333+
}
334+
}
335+
}
336+
}
319337
}
320338

321339
impl ItemProtocol for PyDictRef {

vm/src/obj/objint.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -418,10 +418,7 @@ impl PyInt {
418418

419419
#[pymethod(name = "__hash__")]
420420
pub fn hash(&self, _vm: &VirtualMachine) -> pyhash::PyHash {
421-
match self.value.to_i64() {
422-
Some(value) => (value % pyhash::MODULUS as i64),
423-
None => (&self.value % pyhash::MODULUS).to_i64().unwrap(),
424-
}
421+
pyhash::hash_bigint(&self.value)
425422
}
426423

427424
#[pymethod(name = "__abs__")]

vm/src/obj/objsuper.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use crate::obj::objfunction::PyMethod;
1111
use crate::obj::objstr;
1212
use crate::obj::objtype::{PyClass, PyClassRef};
1313
use crate::pyobject::{
14-
ItemProtocol, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol,
14+
PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol,
1515
};
1616
use crate::scope::NameProtocol;
1717
use crate::vm::VirtualMachine;

vm/src/pyhash.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use num_bigint::BigInt;
2+
use num_traits::ToPrimitive;
13
use std::hash::{Hash, Hasher};
24

35
use crate::obj::objfloat;
@@ -81,3 +83,10 @@ pub fn hash_iter<'a, I: std::iter::Iterator<Item = &'a PyObjectRef>>(
8183
}
8284
Ok(hasher.finish() as PyHash)
8385
}
86+
87+
pub fn hash_bigint(value: &BigInt) -> PyHash {
88+
match value.to_i64() {
89+
Some(i64_value) => (i64_value % MODULUS as i64),
90+
None => (value % MODULUS).to_i64().unwrap(),
91+
}
92+
}

vm/src/pyobject.rs

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,24 +1044,6 @@ pub trait ItemProtocol {
10441044
vm: &VirtualMachine,
10451045
) -> PyResult;
10461046
fn del_item<T: IntoPyObject>(&self, key: T, vm: &VirtualMachine) -> PyResult;
1047-
1048-
#[cfg_attr(feature = "flame-it", flame("ItemProtocol"))]
1049-
fn get_item_option<T: IntoPyObject>(
1050-
&self,
1051-
key: T,
1052-
vm: &VirtualMachine,
1053-
) -> PyResult<Option<PyObjectRef>> {
1054-
match self.get_item(key, vm) {
1055-
Ok(value) => Ok(Some(value)),
1056-
Err(exc) => {
1057-
if objtype::isinstance(&exc, &vm.ctx.exceptions.key_error) {
1058-
Ok(None)
1059-
} else {
1060-
Err(exc)
1061-
}
1062-
}
1063-
}
1064-
}
10651047
}
10661048

10671049
impl ItemProtocol for PyObjectRef {

0 commit comments

Comments
 (0)