Skip to content

Commit f105e72

Browse files
authored
Merge pull request RustPython#1944 from RustPython/coolreader18/_thread-_local
Add _thread._local
2 parents 595f68b + fa2f5d6 commit f105e72

File tree

8 files changed

+379
-6
lines changed

8 files changed

+379
-6
lines changed

Cargo.lock

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Lib/_threading_local.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@
126126
127127
affects what we see:
128128
129-
>>> mydata.number
129+
>>> # TODO: RUSTPYTHON, __slots__
130+
>>> mydata.number #doctest: +SKIP
130131
11
131132
132133
>>> del mydata

Lib/test/test_threadedtempfile.py

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""
2+
Create and delete FILES_PER_THREAD temp files (via tempfile.TemporaryFile)
3+
in each of NUM_THREADS threads, recording the number of successes and
4+
failures. A failure is a bug in tempfile, and may be due to:
5+
6+
+ Trying to create more than one tempfile with the same name.
7+
+ Trying to delete a tempfile that doesn't still exist.
8+
+ Something we've never seen before.
9+
10+
By default, NUM_THREADS == 20 and FILES_PER_THREAD == 50. This is enough to
11+
create about 150 failures per run under Win98SE in 2.0, and runs pretty
12+
quickly. Guido reports needing to boost FILES_PER_THREAD to 500 before
13+
provoking a 2.0 failure under Linux.
14+
"""
15+
16+
import tempfile
17+
18+
from test.support import start_threads
19+
import unittest
20+
import io
21+
import threading
22+
from traceback import print_exc
23+
24+
25+
NUM_THREADS = 20
26+
FILES_PER_THREAD = 50
27+
28+
29+
startEvent = threading.Event()
30+
31+
32+
class TempFileGreedy(threading.Thread):
33+
error_count = 0
34+
ok_count = 0
35+
36+
def run(self):
37+
self.errors = io.StringIO()
38+
startEvent.wait()
39+
for i in range(FILES_PER_THREAD):
40+
try:
41+
f = tempfile.TemporaryFile("w+b")
42+
f.close()
43+
except:
44+
self.error_count += 1
45+
print_exc(file=self.errors)
46+
else:
47+
self.ok_count += 1
48+
49+
50+
class ThreadedTempFileTest(unittest.TestCase):
51+
def test_main(self):
52+
threads = [TempFileGreedy() for i in range(NUM_THREADS)]
53+
with start_threads(threads, startEvent.set):
54+
pass
55+
ok = sum(t.ok_count for t in threads)
56+
errors = [str(t.name) + str(t.errors.getvalue())
57+
for t in threads if t.error_count]
58+
59+
msg = "Errors: errors %d ok %d\n%s" % (len(errors), ok,
60+
'\n'.join(errors))
61+
self.assertEqual(errors, [], msg)
62+
self.assertEqual(ok, NUM_THREADS * FILES_PER_THREAD)
63+
64+
if __name__ == "__main__":
65+
unittest.main()

Lib/test/test_threading_local.py

+225
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
import sys
2+
import unittest
3+
from doctest import DocTestSuite
4+
from test import support
5+
import weakref
6+
# import gc
7+
8+
# Modules under test
9+
import _thread
10+
import threading
11+
import _threading_local
12+
13+
14+
class Weak(object):
15+
pass
16+
17+
def target(local, weaklist):
18+
weak = Weak()
19+
local.weak = weak
20+
weaklist.append(weakref.ref(weak))
21+
22+
23+
class BaseLocalTest:
24+
25+
def test_local_refs(self):
26+
self._local_refs(20)
27+
self._local_refs(50)
28+
self._local_refs(100)
29+
30+
def _local_refs(self, n):
31+
local = self._local()
32+
weaklist = []
33+
for i in range(n):
34+
t = threading.Thread(target=target, args=(local, weaklist))
35+
t.start()
36+
t.join()
37+
del t
38+
39+
# gc.collect()
40+
self.assertEqual(len(weaklist), n)
41+
42+
# XXX _threading_local keeps the local of the last stopped thread alive.
43+
deadlist = [weak for weak in weaklist if weak() is None]
44+
self.assertIn(len(deadlist), (n-1, n))
45+
46+
# Assignment to the same thread local frees it sometimes (!)
47+
local.someothervar = None
48+
# gc.collect()
49+
deadlist = [weak for weak in weaklist if weak() is None]
50+
self.assertIn(len(deadlist), (n-1, n), (n, len(deadlist)))
51+
52+
def test_derived(self):
53+
# Issue 3088: if there is a threads switch inside the __init__
54+
# of a threading.local derived class, the per-thread dictionary
55+
# is created but not correctly set on the object.
56+
# The first member set may be bogus.
57+
import time
58+
class Local(self._local):
59+
def __init__(self):
60+
time.sleep(0.01)
61+
local = Local()
62+
63+
def f(i):
64+
local.x = i
65+
# Simply check that the variable is correctly set
66+
self.assertEqual(local.x, i)
67+
68+
with support.start_threads(threading.Thread(target=f, args=(i,))
69+
for i in range(10)):
70+
pass
71+
72+
def test_derived_cycle_dealloc(self):
73+
# http://bugs.python.org/issue6990
74+
class Local(self._local):
75+
pass
76+
locals = None
77+
passed = False
78+
e1 = threading.Event()
79+
e2 = threading.Event()
80+
81+
def f():
82+
nonlocal passed
83+
# 1) Involve Local in a cycle
84+
cycle = [Local()]
85+
cycle.append(cycle)
86+
cycle[0].foo = 'bar'
87+
88+
# 2) GC the cycle (triggers threadmodule.c::local_clear
89+
# before local_dealloc)
90+
del cycle
91+
# gc.collect()
92+
e1.set()
93+
e2.wait()
94+
95+
# 4) New Locals should be empty
96+
passed = all(not hasattr(local, 'foo') for local in locals)
97+
98+
t = threading.Thread(target=f)
99+
t.start()
100+
e1.wait()
101+
102+
# 3) New Locals should recycle the original's address. Creating
103+
# them in the thread overwrites the thread state and avoids the
104+
# bug
105+
locals = [Local() for i in range(10)]
106+
e2.set()
107+
t.join()
108+
109+
self.assertTrue(passed)
110+
111+
# TODO: RUSTPYTHON, __new__ vs __init__ cooperation
112+
@unittest.expectedFailure
113+
def test_arguments(self):
114+
# Issue 1522237
115+
class MyLocal(self._local):
116+
def __init__(self, *args, **kwargs):
117+
pass
118+
119+
MyLocal(a=1)
120+
MyLocal(1)
121+
self.assertRaises(TypeError, self._local, a=1)
122+
self.assertRaises(TypeError, self._local, 1)
123+
124+
def _test_one_class(self, c):
125+
self._failed = "No error message set or cleared."
126+
obj = c()
127+
e1 = threading.Event()
128+
e2 = threading.Event()
129+
130+
def f1():
131+
obj.x = 'foo'
132+
obj.y = 'bar'
133+
del obj.y
134+
e1.set()
135+
e2.wait()
136+
137+
def f2():
138+
try:
139+
foo = obj.x
140+
except AttributeError:
141+
# This is expected -- we haven't set obj.x in this thread yet!
142+
self._failed = "" # passed
143+
else:
144+
self._failed = ('Incorrectly got value %r from class %r\n' %
145+
(foo, c))
146+
sys.stderr.write(self._failed)
147+
148+
t1 = threading.Thread(target=f1)
149+
t1.start()
150+
e1.wait()
151+
t2 = threading.Thread(target=f2)
152+
t2.start()
153+
t2.join()
154+
# The test is done; just let t1 know it can exit, and wait for it.
155+
e2.set()
156+
t1.join()
157+
158+
self.assertFalse(self._failed, self._failed)
159+
160+
def test_threading_local(self):
161+
self._test_one_class(self._local)
162+
163+
def test_threading_local_subclass(self):
164+
class LocalSubclass(self._local):
165+
"""To test that subclasses behave properly."""
166+
self._test_one_class(LocalSubclass)
167+
168+
def _test_dict_attribute(self, cls):
169+
obj = cls()
170+
obj.x = 5
171+
self.assertEqual(obj.__dict__, {'x': 5})
172+
with self.assertRaises(AttributeError):
173+
obj.__dict__ = {}
174+
with self.assertRaises(AttributeError):
175+
del obj.__dict__
176+
177+
def test_dict_attribute(self):
178+
self._test_dict_attribute(self._local)
179+
180+
def test_dict_attribute_subclass(self):
181+
class LocalSubclass(self._local):
182+
"""To test that subclasses behave properly."""
183+
self._test_dict_attribute(LocalSubclass)
184+
185+
# TODO: RUSTPYTHON, cycle detection/collection
186+
@unittest.expectedFailure
187+
def test_cycle_collection(self):
188+
class X:
189+
pass
190+
191+
x = X()
192+
x.local = self._local()
193+
x.local.x = x
194+
wr = weakref.ref(x)
195+
del x
196+
# gc.collect()
197+
self.assertIsNone(wr())
198+
199+
200+
class ThreadLocalTest(unittest.TestCase, BaseLocalTest):
201+
_local = _thread._local
202+
203+
class PyThreadingLocalTest(unittest.TestCase, BaseLocalTest):
204+
_local = _threading_local.local
205+
206+
207+
def test_main():
208+
suite = unittest.TestSuite()
209+
suite.addTest(DocTestSuite('_threading_local'))
210+
suite.addTest(unittest.makeSuite(ThreadLocalTest))
211+
# suite.addTest(unittest.makeSuite(PyThreadingLocalTest))
212+
213+
local_orig = _threading_local.local
214+
def setUp(test):
215+
_threading_local.local = _thread._local
216+
def tearDown(test):
217+
_threading_local.local = local_orig
218+
suite.addTest(DocTestSuite('_threading_local',
219+
setUp=setUp, tearDown=tearDown)
220+
)
221+
222+
support.run_unittest(suite)
223+
224+
if __name__ == '__main__':
225+
test_main()

vm/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ bstr = "0.2.12"
7272
crossbeam-utils = "0.7"
7373
generational-arena = "0.2"
7474
parking_lot = { git = "https://github.com/Amanieu/parking_lot" } # TODO: use published version
75+
thread_local = "1.0"
7576

7677
## unicode stuff
7778
unicode_names2 = "0.4"

vm/src/obj/objmodule.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,15 @@ impl PyModuleRef {
6767
vm.generic_getattribute_opt(
6868
self.as_object().clone(),
6969
PyString::from("__name__").into_ref(vm),
70+
None,
7071
)
7172
.unwrap_or(None)
7273
.and_then(|obj| obj.payload::<PyString>().map(|s| s.as_str().to_owned()))
7374
}
7475

7576
#[pymethod(magic)]
7677
fn getattribute(self, name: PyStringRef, vm: &VirtualMachine) -> PyResult {
77-
vm.generic_getattribute_opt(self.as_object().clone(), name.clone())?
78+
vm.generic_getattribute_opt(self.as_object().clone(), name.clone(), None)?
7879
.ok_or_else(|| {
7980
let module_name = if let Some(name) = self.name(vm) {
8081
format!(" '{}'", name)

0 commit comments

Comments
 (0)