diff --git a/Cargo.lock b/Cargo.lock index 02cd179ca0..ca47eaad3a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1663,6 +1663,7 @@ dependencies = [ "socket2", "statrs", "subprocess", + "thread_local", "uname", "unic-bidi", "unic-char-property", diff --git a/Lib/_threading_local.py b/Lib/_threading_local.py index 76f10229d2..e520433998 100644 --- a/Lib/_threading_local.py +++ b/Lib/_threading_local.py @@ -126,7 +126,8 @@ affects what we see: - >>> mydata.number + >>> # TODO: RUSTPYTHON, __slots__ + >>> mydata.number #doctest: +SKIP 11 >>> del mydata diff --git a/Lib/test/test_threadedtempfile.py b/Lib/test/test_threadedtempfile.py new file mode 100644 index 0000000000..e1d7a10179 --- /dev/null +++ b/Lib/test/test_threadedtempfile.py @@ -0,0 +1,65 @@ +""" +Create and delete FILES_PER_THREAD temp files (via tempfile.TemporaryFile) +in each of NUM_THREADS threads, recording the number of successes and +failures. A failure is a bug in tempfile, and may be due to: + ++ Trying to create more than one tempfile with the same name. ++ Trying to delete a tempfile that doesn't still exist. ++ Something we've never seen before. + +By default, NUM_THREADS == 20 and FILES_PER_THREAD == 50. This is enough to +create about 150 failures per run under Win98SE in 2.0, and runs pretty +quickly. Guido reports needing to boost FILES_PER_THREAD to 500 before +provoking a 2.0 failure under Linux. +""" + +import tempfile + +from test.support import start_threads +import unittest +import io +import threading +from traceback import print_exc + + +NUM_THREADS = 20 +FILES_PER_THREAD = 50 + + +startEvent = threading.Event() + + +class TempFileGreedy(threading.Thread): + error_count = 0 + ok_count = 0 + + def run(self): + self.errors = io.StringIO() + startEvent.wait() + for i in range(FILES_PER_THREAD): + try: + f = tempfile.TemporaryFile("w+b") + f.close() + except: + self.error_count += 1 + print_exc(file=self.errors) + else: + self.ok_count += 1 + + +class ThreadedTempFileTest(unittest.TestCase): + def test_main(self): + threads = [TempFileGreedy() for i in range(NUM_THREADS)] + with start_threads(threads, startEvent.set): + pass + ok = sum(t.ok_count for t in threads) + errors = [str(t.name) + str(t.errors.getvalue()) + for t in threads if t.error_count] + + msg = "Errors: errors %d ok %d\n%s" % (len(errors), ok, + '\n'.join(errors)) + self.assertEqual(errors, [], msg) + self.assertEqual(ok, NUM_THREADS * FILES_PER_THREAD) + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_threading_local.py b/Lib/test/test_threading_local.py new file mode 100644 index 0000000000..a2ab266fba --- /dev/null +++ b/Lib/test/test_threading_local.py @@ -0,0 +1,225 @@ +import sys +import unittest +from doctest import DocTestSuite +from test import support +import weakref +# import gc + +# Modules under test +import _thread +import threading +import _threading_local + + +class Weak(object): + pass + +def target(local, weaklist): + weak = Weak() + local.weak = weak + weaklist.append(weakref.ref(weak)) + + +class BaseLocalTest: + + def test_local_refs(self): + self._local_refs(20) + self._local_refs(50) + self._local_refs(100) + + def _local_refs(self, n): + local = self._local() + weaklist = [] + for i in range(n): + t = threading.Thread(target=target, args=(local, weaklist)) + t.start() + t.join() + del t + + # gc.collect() + self.assertEqual(len(weaklist), n) + + # XXX _threading_local keeps the local of the last stopped thread alive. + deadlist = [weak for weak in weaklist if weak() is None] + self.assertIn(len(deadlist), (n-1, n)) + + # Assignment to the same thread local frees it sometimes (!) + local.someothervar = None + # gc.collect() + deadlist = [weak for weak in weaklist if weak() is None] + self.assertIn(len(deadlist), (n-1, n), (n, len(deadlist))) + + def test_derived(self): + # Issue 3088: if there is a threads switch inside the __init__ + # of a threading.local derived class, the per-thread dictionary + # is created but not correctly set on the object. + # The first member set may be bogus. + import time + class Local(self._local): + def __init__(self): + time.sleep(0.01) + local = Local() + + def f(i): + local.x = i + # Simply check that the variable is correctly set + self.assertEqual(local.x, i) + + with support.start_threads(threading.Thread(target=f, args=(i,)) + for i in range(10)): + pass + + def test_derived_cycle_dealloc(self): + # http://bugs.python.org/issue6990 + class Local(self._local): + pass + locals = None + passed = False + e1 = threading.Event() + e2 = threading.Event() + + def f(): + nonlocal passed + # 1) Involve Local in a cycle + cycle = [Local()] + cycle.append(cycle) + cycle[0].foo = 'bar' + + # 2) GC the cycle (triggers threadmodule.c::local_clear + # before local_dealloc) + del cycle + # gc.collect() + e1.set() + e2.wait() + + # 4) New Locals should be empty + passed = all(not hasattr(local, 'foo') for local in locals) + + t = threading.Thread(target=f) + t.start() + e1.wait() + + # 3) New Locals should recycle the original's address. Creating + # them in the thread overwrites the thread state and avoids the + # bug + locals = [Local() for i in range(10)] + e2.set() + t.join() + + self.assertTrue(passed) + + # TODO: RUSTPYTHON, __new__ vs __init__ cooperation + @unittest.expectedFailure + def test_arguments(self): + # Issue 1522237 + class MyLocal(self._local): + def __init__(self, *args, **kwargs): + pass + + MyLocal(a=1) + MyLocal(1) + self.assertRaises(TypeError, self._local, a=1) + self.assertRaises(TypeError, self._local, 1) + + def _test_one_class(self, c): + self._failed = "No error message set or cleared." + obj = c() + e1 = threading.Event() + e2 = threading.Event() + + def f1(): + obj.x = 'foo' + obj.y = 'bar' + del obj.y + e1.set() + e2.wait() + + def f2(): + try: + foo = obj.x + except AttributeError: + # This is expected -- we haven't set obj.x in this thread yet! + self._failed = "" # passed + else: + self._failed = ('Incorrectly got value %r from class %r\n' % + (foo, c)) + sys.stderr.write(self._failed) + + t1 = threading.Thread(target=f1) + t1.start() + e1.wait() + t2 = threading.Thread(target=f2) + t2.start() + t2.join() + # The test is done; just let t1 know it can exit, and wait for it. + e2.set() + t1.join() + + self.assertFalse(self._failed, self._failed) + + def test_threading_local(self): + self._test_one_class(self._local) + + def test_threading_local_subclass(self): + class LocalSubclass(self._local): + """To test that subclasses behave properly.""" + self._test_one_class(LocalSubclass) + + def _test_dict_attribute(self, cls): + obj = cls() + obj.x = 5 + self.assertEqual(obj.__dict__, {'x': 5}) + with self.assertRaises(AttributeError): + obj.__dict__ = {} + with self.assertRaises(AttributeError): + del obj.__dict__ + + def test_dict_attribute(self): + self._test_dict_attribute(self._local) + + def test_dict_attribute_subclass(self): + class LocalSubclass(self._local): + """To test that subclasses behave properly.""" + self._test_dict_attribute(LocalSubclass) + + # TODO: RUSTPYTHON, cycle detection/collection + @unittest.expectedFailure + def test_cycle_collection(self): + class X: + pass + + x = X() + x.local = self._local() + x.local.x = x + wr = weakref.ref(x) + del x + # gc.collect() + self.assertIsNone(wr()) + + +class ThreadLocalTest(unittest.TestCase, BaseLocalTest): + _local = _thread._local + +class PyThreadingLocalTest(unittest.TestCase, BaseLocalTest): + _local = _threading_local.local + + +def test_main(): + suite = unittest.TestSuite() + suite.addTest(DocTestSuite('_threading_local')) + suite.addTest(unittest.makeSuite(ThreadLocalTest)) + # suite.addTest(unittest.makeSuite(PyThreadingLocalTest)) + + local_orig = _threading_local.local + def setUp(test): + _threading_local.local = _thread._local + def tearDown(test): + _threading_local.local = local_orig + suite.addTest(DocTestSuite('_threading_local', + setUp=setUp, tearDown=tearDown) + ) + + support.run_unittest(suite) + +if __name__ == '__main__': + test_main() diff --git a/vm/Cargo.toml b/vm/Cargo.toml index 28b0cc435d..84cb4871ee 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -72,6 +72,7 @@ bstr = "0.2.12" crossbeam-utils = "0.7" generational-arena = "0.2" parking_lot = { git = "https://github.com/Amanieu/parking_lot" } # TODO: use published version +thread_local = "1.0" ## unicode stuff unicode_names2 = "0.4" diff --git a/vm/src/obj/objmodule.rs b/vm/src/obj/objmodule.rs index c68d16e61b..cd43681ddf 100644 --- a/vm/src/obj/objmodule.rs +++ b/vm/src/obj/objmodule.rs @@ -67,6 +67,7 @@ impl PyModuleRef { vm.generic_getattribute_opt( self.as_object().clone(), PyString::from("__name__").into_ref(vm), + None, ) .unwrap_or(None) .and_then(|obj| obj.payload::().map(|s| s.as_str().to_owned())) @@ -74,7 +75,7 @@ impl PyModuleRef { #[pymethod(magic)] fn getattribute(self, name: PyStringRef, vm: &VirtualMachine) -> PyResult { - vm.generic_getattribute_opt(self.as_object().clone(), name.clone())? + vm.generic_getattribute_opt(self.as_object().clone(), name.clone(), None)? .ok_or_else(|| { let module_name = if let Some(name) = self.name(vm) { format!(" '{}'", name) diff --git a/vm/src/stdlib/thread.rs b/vm/src/stdlib/thread.rs index 1a3ec9a94d..ca1b71fdfa 100644 --- a/vm/src/stdlib/thread.rs +++ b/vm/src/stdlib/thread.rs @@ -2,11 +2,12 @@ use crate::exceptions; use crate::function::{Args, KwArgs, OptionalArg, PyFuncArgs}; use crate::obj::objdict::PyDictRef; +use crate::obj::objstr::PyStringRef; use crate::obj::objtuple::PyTupleRef; use crate::obj::objtype::PyClassRef; use crate::pyobject::{ - Either, IdProtocol, PyCallable, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, - TypeProtocol, + Either, IdProtocol, ItemProtocol, PyCallable, PyClassImpl, PyObjectRef, PyRef, PyResult, + PyValue, TypeProtocol, }; use crate::vm::VirtualMachine; @@ -14,6 +15,8 @@ use parking_lot::{ lock_api::{RawMutex as RawMutexT, RawMutexTimed, RawReentrantMutex}, RawMutex, RawThreadId, }; +use thread_local::ThreadLocal; + use std::cell::RefCell; use std::io::Write; use std::time::Duration; @@ -264,12 +267,85 @@ fn thread_count(vm: &VirtualMachine) -> usize { vm.state.thread_count.load() } +#[pyclass(name = "_local")] +#[derive(Debug)] +struct PyLocal { + data: ThreadLocal, +} + +impl PyValue for PyLocal { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("_thread", "_local") + } +} + +#[pyimpl(flags(BASETYPE))] +impl PyLocal { + fn ldict(&self, vm: &VirtualMachine) -> PyDictRef { + self.data.get_or(|| vm.ctx.new_dict()).clone() + } + + #[pyslot] + fn tp_new(cls: PyClassRef, _args: PyFuncArgs, vm: &VirtualMachine) -> PyResult> { + PyLocal { + data: ThreadLocal::new(), + } + .into_ref_with_type(vm, cls) + } + + #[pymethod(magic)] + fn getattribute(zelf: PyRef, attr: PyStringRef, vm: &VirtualMachine) -> PyResult { + let ldict = zelf.ldict(vm); + if attr.as_str() == "__dict__" { + Ok(ldict.into_object()) + } else { + let zelf = zelf.into_object(); + vm.generic_getattribute_opt(zelf.clone(), attr.clone(), Some(ldict))? + .ok_or_else(|| { + vm.new_attribute_error(format!("{} has no attribute '{}'", zelf, attr)) + }) + } + } + + #[pymethod(magic)] + fn setattr( + zelf: PyRef, + attr: PyStringRef, + value: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<()> { + if attr.as_str() == "__dict__" { + Err(vm.new_attribute_error(format!( + "{} attribute '__dict__' is read-only", + zelf.as_object() + ))) + } else { + zelf.ldict(vm).set_item(attr.as_object(), value, vm)?; + Ok(()) + } + } + + #[pymethod(magic)] + fn delattr(zelf: PyRef, attr: PyStringRef, vm: &VirtualMachine) -> PyResult<()> { + if attr.as_str() == "__dict__" { + Err(vm.new_attribute_error(format!( + "{} attribute '__dict__' is read-only", + zelf.as_object() + ))) + } else { + zelf.ldict(vm).del_item(attr.as_object(), vm)?; + Ok(()) + } + } +} + pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; py_module!(vm, "_thread", { "RLock" => PyRLock::make_class(ctx), "LockType" => PyLock::make_class(ctx), + "_local" => PyLocal::make_class(ctx), "get_ident" => ctx.new_function(thread_get_ident), "allocate_lock" => ctx.new_function(thread_allocate_lock), "start_new_thread" => ctx.new_function(thread_start_new_thread), diff --git a/vm/src/vm.rs b/vm/src/vm.rs index 9bf5e2f700..488a474eb1 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -1001,7 +1001,7 @@ impl VirtualMachine { } pub fn generic_getattribute(&self, obj: PyObjectRef, name: PyStringRef) -> PyResult { - self.generic_getattribute_opt(obj.clone(), name.clone())? + self.generic_getattribute_opt(obj.clone(), name.clone(), None)? .ok_or_else(|| self.new_attribute_error(format!("{} has no attribute '{}'", obj, name))) } @@ -1010,6 +1010,7 @@ impl VirtualMachine { &self, obj: PyObjectRef, name_str: PyStringRef, + dict: Option, ) -> PyResult> { let name = name_str.as_str(); let cls = obj.class(); @@ -1023,7 +1024,9 @@ impl VirtualMachine { } } - let attr = if let Some(dict) = obj.dict() { + let dict = dict.or_else(|| obj.dict()); + + let attr = if let Some(dict) = dict { dict.get_item_option(name_str.as_str(), self)? } else { None