Skip to content

Commit 0cea276

Browse files
committed
Make PyRandom ThreadSafe
1 parent 4ca4709 commit 0cea276

File tree

3 files changed

+65
-22
lines changed

3 files changed

+65
-22
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vm/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ num_enum = "0.4"
7070
smallbox = "0.8"
7171
bstr = "0.2.12"
7272
crossbeam-utils = "0.7"
73+
generational-arena = "0.2"
7374

7475
## unicode stuff
7576
unicode_names2 = "0.4"

vm/src/stdlib/random.rs

Lines changed: 63 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ mod _random {
77
use crate::function::OptionalOption;
88
use crate::obj::objint::PyIntRef;
99
use crate::obj::objtype::PyClassRef;
10-
use crate::pyobject::{PyClassImpl, PyRef, PyResult, PyValue};
10+
use crate::pyobject::{PyClassImpl, PyRef, PyResult, PyValue, ThreadSafe};
1111
use crate::VirtualMachine;
12+
use generational_arena::{self, Arena};
1213
use num_bigint::{BigInt, Sign};
1314
use num_traits::Signed;
1415
use rand::RngCore;
@@ -53,12 +54,51 @@ mod _random {
5354
}
5455
}
5556

57+
thread_local!(static RNG_HANDLES: RefCell<Arena<PyRng>> = RefCell::new(Arena::new()));
58+
59+
#[derive(Debug)]
60+
struct RngHandle(generational_arena::Index);
61+
impl RngHandle {
62+
fn new(rng: PyRng) -> Self {
63+
let idx = RNG_HANDLES.with(|arena| arena.borrow_mut().insert(rng));
64+
RngHandle(idx)
65+
}
66+
fn exec<F, R>(&self, func: F) -> R
67+
where
68+
F: Fn(&mut PyRng) -> R,
69+
{
70+
RNG_HANDLES.with(|arena| {
71+
func(
72+
arena
73+
.borrow_mut()
74+
.get_mut(self.0)
75+
.expect("index was removed"),
76+
)
77+
})
78+
}
79+
fn replace(&self, rng: PyRng) {
80+
RNG_HANDLES.with(|arena| {
81+
*arena
82+
.borrow_mut()
83+
.get_mut(self.0)
84+
.expect("index was removed") = rng
85+
})
86+
}
87+
}
88+
impl Drop for RngHandle {
89+
fn drop(&mut self) {
90+
RNG_HANDLES.with(|arena| arena.borrow_mut().remove(self.0));
91+
}
92+
}
93+
5694
#[pyclass(name = "Random")]
5795
#[derive(Debug)]
5896
struct PyRandom {
59-
rng: RefCell<PyRng>,
97+
rng: RngHandle,
6098
}
6199

100+
impl ThreadSafe for PyRandom {}
101+
62102
impl PyValue for PyRandom {
63103
fn class(vm: &VirtualMachine) -> PyClassRef {
64104
vm.class("_random", "Random")
@@ -70,14 +110,14 @@ mod _random {
70110
#[pyslot(new)]
71111
fn new(cls: PyClassRef, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
72112
PyRandom {
73-
rng: RefCell::new(PyRng::default()),
113+
rng: RngHandle::new(PyRng::default()),
74114
}
75115
.into_ref_with_type(vm, cls)
76116
}
77117

78118
#[pymethod]
79119
fn random(&self) -> f64 {
80-
mt19937::gen_res53(&mut *self.rng.borrow_mut())
120+
self.rng.exec(mt19937::gen_res53)
81121
}
82122

83123
#[pymethod]
@@ -93,31 +133,32 @@ mod _random {
93133
}
94134
};
95135

96-
*self.rng.borrow_mut() = new_rng;
136+
self.rng.replace(new_rng);
97137
}
98138

99139
#[pymethod]
100-
fn getrandbits(&self, mut k: usize) -> BigInt {
101-
let mut rng = self.rng.borrow_mut();
140+
fn getrandbits(&self, k: usize) -> BigInt {
141+
self.rng.exec(|rng| {
142+
let mut k = k;
143+
let mut gen_u32 = |k| rng.next_u32() >> (32 - k) as u32;
102144

103-
let mut gen_u32 = |k| rng.next_u32() >> (32 - k) as u32;
104-
105-
if k <= 32 {
106-
return gen_u32(k).into();
107-
}
145+
if k <= 32 {
146+
return gen_u32(k).into();
147+
}
108148

109-
let words = (k - 1) / 8 + 1;
110-
let mut wordarray = vec![0u32; words];
149+
let words = (k - 1) / 8 + 1;
150+
let mut wordarray = vec![0u32; words];
111151

112-
let it = wordarray.iter_mut();
113-
#[cfg(target_endian = "big")]
114-
let it = it.rev();
115-
for word in it {
116-
*word = gen_u32(k);
117-
k -= 32;
118-
}
152+
let it = wordarray.iter_mut();
153+
#[cfg(target_endian = "big")]
154+
let it = it.rev();
155+
for word in it {
156+
*word = gen_u32(k);
157+
k -= 32;
158+
}
119159

120-
BigInt::from_slice(Sign::NoSign, &wordarray)
160+
BigInt::from_slice(Sign::NoSign, &wordarray)
161+
})
121162
}
122163
}
123164
}

0 commit comments

Comments
 (0)