Skip to content

Use StdRng instead of ThreadRng in _random #1939

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 23, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 26 additions & 63 deletions vm/src/stdlib/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@ mod _random {
use crate::obj::objtype::PyClassRef;
use crate::pyobject::{PyClassImpl, PyRef, PyResult, PyValue};
use crate::VirtualMachine;
use generational_arena::{self, Arena};
use num_bigint::{BigInt, Sign};
use num_traits::Signed;
use rand::RngCore;
use std::cell::RefCell;
use rand::{rngs::StdRng, RngCore, SeedableRng};

use std::sync::Mutex;

#[derive(Debug)]
enum PyRng {
Std(rand::rngs::ThreadRng),
Std(Box<StdRng>),
MT(Box<mt19937::MT19937>),
}

impl Default for PyRng {
fn default() -> Self {
PyRng::Std(rand::thread_rng())
PyRng::Std(Box::new(StdRng::from_entropy()))
}
}

Expand Down Expand Up @@ -54,47 +54,10 @@ mod _random {
}
}

thread_local!(static RNG_HANDLES: RefCell<Arena<PyRng>> = RefCell::new(Arena::new()));

#[derive(Debug)]
struct RngHandle(generational_arena::Index);
impl RngHandle {
fn new(rng: PyRng) -> Self {
let idx = RNG_HANDLES.with(|arena| arena.borrow_mut().insert(rng));
RngHandle(idx)
}
fn exec<F, R>(&self, func: F) -> R
where
F: Fn(&mut PyRng) -> R,
{
RNG_HANDLES.with(|arena| {
func(
arena
.borrow_mut()
.get_mut(self.0)
.expect("index was removed"),
)
})
}
fn replace(&self, rng: PyRng) {
RNG_HANDLES.with(|arena| {
*arena
.borrow_mut()
.get_mut(self.0)
.expect("index was removed") = rng
})
}
}
impl Drop for RngHandle {
fn drop(&mut self) {
RNG_HANDLES.with(|arena| arena.borrow_mut().remove(self.0));
}
}

#[pyclass(name = "Random")]
#[derive(Debug)]
struct PyRandom {
rng: RngHandle,
rng: Mutex<PyRng>,
}

impl PyValue for PyRandom {
Expand All @@ -108,14 +71,15 @@ mod _random {
#[pyslot(new)]
fn new(cls: PyClassRef, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
PyRandom {
rng: RngHandle::new(PyRng::default()),
rng: Mutex::default(),
}
.into_ref_with_type(vm, cls)
}

#[pymethod]
fn random(&self) -> f64 {
self.rng.exec(mt19937::gen_res53)
let mut rng = self.rng.lock().unwrap();
mt19937::gen_res53(&mut *rng)
}

#[pymethod]
Expand All @@ -131,32 +95,31 @@ mod _random {
}
};

self.rng.replace(new_rng);
*self.rng.lock().unwrap() = new_rng;
}

#[pymethod]
fn getrandbits(&self, k: usize) -> BigInt {
self.rng.exec(|rng| {
let mut k = k;
let mut gen_u32 = |k| rng.next_u32() >> (32 - k) as u32;
let mut rng = self.rng.lock().unwrap();
let mut k = k;
let mut gen_u32 = |k| rng.next_u32() >> (32 - k) as u32;

if k <= 32 {
return gen_u32(k).into();
}
if k <= 32 {
return gen_u32(k).into();
}

let words = (k - 1) / 8 + 1;
let mut wordarray = vec![0u32; words];
let words = (k - 1) / 8 + 1;
let mut wordarray = vec![0u32; words];

let it = wordarray.iter_mut();
#[cfg(target_endian = "big")]
let it = it.rev();
for word in it {
*word = gen_u32(k);
k -= 32;
}
let it = wordarray.iter_mut();
#[cfg(target_endian = "big")]
let it = it.rev();
for word in it {
*word = gen_u32(k);
k -= 32;
}

BigInt::from_slice(Sign::NoSign, &wordarray)
})
BigInt::from_slice(Sign::NoSign, &wordarray)
}
}
}