Skip to content

Add _thread._local #1944

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Lib/_threading_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@

affects what we see:

>>> mydata.number
>>> # TODO: RUSTPYTHON, __slots__
>>> mydata.number #doctest: +SKIP
11

>>> del mydata
Expand Down
65 changes: 65 additions & 0 deletions Lib/test/test_threadedtempfile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
Create and delete FILES_PER_THREAD temp files (via tempfile.TemporaryFile)
in each of NUM_THREADS threads, recording the number of successes and
failures. A failure is a bug in tempfile, and may be due to:

+ Trying to create more than one tempfile with the same name.
+ Trying to delete a tempfile that doesn't still exist.
+ Something we've never seen before.

By default, NUM_THREADS == 20 and FILES_PER_THREAD == 50. This is enough to
create about 150 failures per run under Win98SE in 2.0, and runs pretty
quickly. Guido reports needing to boost FILES_PER_THREAD to 500 before
provoking a 2.0 failure under Linux.
"""

import tempfile

from test.support import start_threads
import unittest
import io
import threading
from traceback import print_exc


NUM_THREADS = 20
FILES_PER_THREAD = 50


startEvent = threading.Event()


class TempFileGreedy(threading.Thread):
error_count = 0
ok_count = 0

def run(self):
self.errors = io.StringIO()
startEvent.wait()
for i in range(FILES_PER_THREAD):
try:
f = tempfile.TemporaryFile("w+b")
f.close()
except:
self.error_count += 1
print_exc(file=self.errors)
else:
self.ok_count += 1


class ThreadedTempFileTest(unittest.TestCase):
def test_main(self):
threads = [TempFileGreedy() for i in range(NUM_THREADS)]
with start_threads(threads, startEvent.set):
pass
ok = sum(t.ok_count for t in threads)
errors = [str(t.name) + str(t.errors.getvalue())
for t in threads if t.error_count]

msg = "Errors: errors %d ok %d\n%s" % (len(errors), ok,
'\n'.join(errors))
self.assertEqual(errors, [], msg)
self.assertEqual(ok, NUM_THREADS * FILES_PER_THREAD)

if __name__ == "__main__":
unittest.main()
225 changes: 225 additions & 0 deletions Lib/test/test_threading_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
import sys
import unittest
from doctest import DocTestSuite
from test import support
import weakref
# import gc

# Modules under test
import _thread
import threading
import _threading_local


class Weak(object):
pass

def target(local, weaklist):
weak = Weak()
local.weak = weak
weaklist.append(weakref.ref(weak))


class BaseLocalTest:

def test_local_refs(self):
self._local_refs(20)
self._local_refs(50)
self._local_refs(100)

def _local_refs(self, n):
local = self._local()
weaklist = []
for i in range(n):
t = threading.Thread(target=target, args=(local, weaklist))
t.start()
t.join()
del t

# gc.collect()
self.assertEqual(len(weaklist), n)

# XXX _threading_local keeps the local of the last stopped thread alive.
deadlist = [weak for weak in weaklist if weak() is None]
self.assertIn(len(deadlist), (n-1, n))

# Assignment to the same thread local frees it sometimes (!)
local.someothervar = None
# gc.collect()
deadlist = [weak for weak in weaklist if weak() is None]
self.assertIn(len(deadlist), (n-1, n), (n, len(deadlist)))

def test_derived(self):
# Issue 3088: if there is a threads switch inside the __init__
# of a threading.local derived class, the per-thread dictionary
# is created but not correctly set on the object.
# The first member set may be bogus.
import time
class Local(self._local):
def __init__(self):
time.sleep(0.01)
local = Local()

def f(i):
local.x = i
# Simply check that the variable is correctly set
self.assertEqual(local.x, i)

with support.start_threads(threading.Thread(target=f, args=(i,))
for i in range(10)):
pass

def test_derived_cycle_dealloc(self):
# http://bugs.python.org/issue6990
class Local(self._local):
pass
locals = None
passed = False
e1 = threading.Event()
e2 = threading.Event()

def f():
nonlocal passed
# 1) Involve Local in a cycle
cycle = [Local()]
cycle.append(cycle)
cycle[0].foo = 'bar'

# 2) GC the cycle (triggers threadmodule.c::local_clear
# before local_dealloc)
del cycle
# gc.collect()
e1.set()
e2.wait()

# 4) New Locals should be empty
passed = all(not hasattr(local, 'foo') for local in locals)

t = threading.Thread(target=f)
t.start()
e1.wait()

# 3) New Locals should recycle the original's address. Creating
# them in the thread overwrites the thread state and avoids the
# bug
locals = [Local() for i in range(10)]
e2.set()
t.join()

self.assertTrue(passed)

# TODO: RUSTPYTHON, __new__ vs __init__ cooperation
@unittest.expectedFailure
def test_arguments(self):
# Issue 1522237
class MyLocal(self._local):
def __init__(self, *args, **kwargs):
pass

MyLocal(a=1)
MyLocal(1)
self.assertRaises(TypeError, self._local, a=1)
self.assertRaises(TypeError, self._local, 1)

def _test_one_class(self, c):
self._failed = "No error message set or cleared."
obj = c()
e1 = threading.Event()
e2 = threading.Event()

def f1():
obj.x = 'foo'
obj.y = 'bar'
del obj.y
e1.set()
e2.wait()

def f2():
try:
foo = obj.x
except AttributeError:
# This is expected -- we haven't set obj.x in this thread yet!
self._failed = "" # passed
else:
self._failed = ('Incorrectly got value %r from class %r\n' %
(foo, c))
sys.stderr.write(self._failed)

t1 = threading.Thread(target=f1)
t1.start()
e1.wait()
t2 = threading.Thread(target=f2)
t2.start()
t2.join()
# The test is done; just let t1 know it can exit, and wait for it.
e2.set()
t1.join()

self.assertFalse(self._failed, self._failed)

def test_threading_local(self):
self._test_one_class(self._local)

def test_threading_local_subclass(self):
class LocalSubclass(self._local):
"""To test that subclasses behave properly."""
self._test_one_class(LocalSubclass)

def _test_dict_attribute(self, cls):
obj = cls()
obj.x = 5
self.assertEqual(obj.__dict__, {'x': 5})
with self.assertRaises(AttributeError):
obj.__dict__ = {}
with self.assertRaises(AttributeError):
del obj.__dict__

def test_dict_attribute(self):
self._test_dict_attribute(self._local)

def test_dict_attribute_subclass(self):
class LocalSubclass(self._local):
"""To test that subclasses behave properly."""
self._test_dict_attribute(LocalSubclass)

# TODO: RUSTPYTHON, cycle detection/collection
@unittest.expectedFailure
def test_cycle_collection(self):
class X:
pass

x = X()
x.local = self._local()
x.local.x = x
wr = weakref.ref(x)
del x
# gc.collect()
self.assertIsNone(wr())


class ThreadLocalTest(unittest.TestCase, BaseLocalTest):
_local = _thread._local

class PyThreadingLocalTest(unittest.TestCase, BaseLocalTest):
_local = _threading_local.local


def test_main():
suite = unittest.TestSuite()
suite.addTest(DocTestSuite('_threading_local'))
suite.addTest(unittest.makeSuite(ThreadLocalTest))
# suite.addTest(unittest.makeSuite(PyThreadingLocalTest))

local_orig = _threading_local.local
def setUp(test):
_threading_local.local = _thread._local
def tearDown(test):
_threading_local.local = local_orig
suite.addTest(DocTestSuite('_threading_local',
setUp=setUp, tearDown=tearDown)
)

support.run_unittest(suite)

if __name__ == '__main__':
test_main()
1 change: 1 addition & 0 deletions vm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ bstr = "0.2.12"
crossbeam-utils = "0.7"
generational-arena = "0.2"
parking_lot = { git = "https://github.com/Amanieu/parking_lot" } # TODO: use published version
thread_local = "1.0"

## unicode stuff
unicode_names2 = "0.4"
Expand Down
3 changes: 2 additions & 1 deletion vm/src/obj/objmodule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,15 @@ impl PyModuleRef {
vm.generic_getattribute_opt(
self.as_object().clone(),
PyString::from("__name__").into_ref(vm),
None,
)
.unwrap_or(None)
.and_then(|obj| obj.payload::<PyString>().map(|s| s.as_str().to_owned()))
}

#[pymethod(magic)]
fn getattribute(self, name: PyStringRef, vm: &VirtualMachine) -> PyResult {
vm.generic_getattribute_opt(self.as_object().clone(), name.clone())?
vm.generic_getattribute_opt(self.as_object().clone(), name.clone(), None)?
.ok_or_else(|| {
let module_name = if let Some(name) = self.name(vm) {
format!(" '{}'", name)
Expand Down
Loading