Skip to content

Commit 866bc02

Browse files
authored
Merge pull request #8071 from gfyoung/randint-tempita
MAINT: Add Tempita to randint helpers
2 parents 55ece58 + 7fdfa6b commit 866bc02

File tree

4 files changed

+124
-12
lines changed

4 files changed

+124
-12
lines changed

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ matrix:
5252
apt:
5353
packages:
5454
- *common_packages
55+
- cython3-dbg
5556
- python3-dbg
5657
- python3-dev
5758
- python3-nose

numpy/random/mtrand/mtrand.pyx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
2323

2424
include "Python.pxi"
25+
include "randint_helpers.pxi"
2526
include "numpy.pxd"
2627
include "cpython/pycapsule.pxd"
2728

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""
2+
Template for each `dtype` helper function in `np.random.randint`.
3+
"""
4+
5+
{{py:
6+
7+
dtypes = (
8+
('bool', 'bool', 'bool_'),
9+
('int8', 'uint8', 'int8'),
10+
('int16', 'uint16', 'int16'),
11+
('int32', 'uint32', 'int32'),
12+
('int64', 'uint64', 'int64'),
13+
('uint8', 'uint8', 'uint8'),
14+
('uint16', 'uint16', 'uint16'),
15+
('uint32', 'uint32', 'uint32'),
16+
('uint64', 'uint64', 'uint64'),
17+
)
18+
19+
def get_dispatch(dtypes):
20+
for npy_dt, npy_udt, np_dt in dtypes:
21+
yield npy_dt, npy_udt, np_dt
22+
}}
23+
24+
{{for npy_dt, npy_udt, np_dt in get_dispatch(dtypes)}}
25+
26+
def _rand_{{npy_dt}}(low, high, size, rngstate):
27+
"""
28+
_rand_{{npy_dt}}(low, high, size, rngstate)
29+
30+
Return random np.{{np_dt}} integers between ``low`` and ``high``, inclusive.
31+
32+
Return random integers from the "discrete uniform" distribution in the
33+
closed interval [``low``, ``high``). On entry the arguments are presumed
34+
to have been validated for size and order for the np.{{np_dt}} type.
35+
36+
Parameters
37+
----------
38+
low : int
39+
Lowest (signed) integer to be drawn from the distribution.
40+
high : int
41+
Highest (signed) integer to be drawn from the distribution.
42+
size : int or tuple of ints
43+
Output shape. If the given shape is, e.g., ``(m, n, k)``, then
44+
``m * n * k`` samples are drawn. Default is None, in which case a
45+
single value is returned.
46+
rngstate : encapsulated pointer to rk_state
47+
The specific type depends on the python version. In Python 2 it is
48+
a PyCObject, in Python 3 a PyCapsule object.
49+
50+
Returns
51+
-------
52+
out : python integer or ndarray of np.{{np_dt}}
53+
`size`-shaped array of random integers from the appropriate
54+
distribution, or a single such random int if `size` not provided.
55+
56+
"""
57+
cdef npy_{{npy_udt}} off, rng, buf
58+
cdef npy_{{npy_udt}} *out
59+
cdef ndarray array "arrayObject"
60+
cdef npy_intp cnt
61+
cdef rk_state *state = <rk_state *>PyCapsule_GetPointer(rngstate, NULL)
62+
63+
rng = <npy_{{npy_udt}}>(high - low)
64+
off = <npy_{{npy_udt}}>(<npy_{{npy_dt}}>low)
65+
66+
if size is None:
67+
rk_random_{{npy_udt}}(off, rng, 1, &buf, state)
68+
return np.{{np_dt}}(<npy_{{npy_dt}}>buf)
69+
else:
70+
array = <ndarray>np.empty(size, np.{{np_dt}})
71+
cnt = PyArray_SIZE(array)
72+
array_data = <npy_{{npy_udt}} *>PyArray_DATA(array)
73+
with nogil:
74+
rk_random_{{npy_udt}}(off, rng, cnt, array_data, state)
75+
return array
76+
77+
{{endfor}}

tools/cythonize.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,24 @@ def process_tempita_pyx(fromfile, tofile):
100100
f.write(pyxcontent)
101101
process_pyx(pyxfile, tofile)
102102

103+
104+
def process_tempita_pxi(fromfile, tofile):
105+
try:
106+
try:
107+
from Cython import Tempita as tempita
108+
except ImportError:
109+
import tempita
110+
except ImportError:
111+
raise Exception('Building %s requires Tempita: '
112+
'pip install --user Tempita' % VENDOR)
113+
assert fromfile.endswith('.pxi.in')
114+
assert tofile.endswith('.pxi')
115+
with open(fromfile, "r") as f:
116+
tmpl = f.read()
117+
pyxcontent = tempita.sub(tmpl)
118+
with open(tofile, "w") as f:
119+
f.write(pyxcontent)
120+
103121
rules = {
104122
# fromext : function
105123
'.pyx' : process_pyx,
@@ -170,22 +188,37 @@ def process(path, fromfile, tofile, processor_function, hash_db):
170188
def find_process_files(root_dir):
171189
hash_db = load_hashes(HASH_FILE)
172190
for cur_dir, dirs, files in os.walk(root_dir):
191+
# .pxi or .pxi.in files are most likely dependencies for
192+
# .pyx files, so we need to process them first
193+
files.sort(key=lambda name: (name.endswith('.pxi') or
194+
name.endswith('.pxi.in')),
195+
reverse=True)
196+
173197
for filename in files:
174198
in_file = os.path.join(cur_dir, filename + ".in")
175199
if filename.endswith('.pyx') and os.path.isfile(in_file):
176200
continue
177-
for fromext, function in rules.items():
178-
if filename.endswith(fromext):
179-
toext = ".c"
180-
with open(os.path.join(cur_dir, filename), 'rb') as f:
181-
data = f.read()
182-
m = re.search(br"^\s*#\s*distutils:\s*language\s*=\s*c\+\+\s*$", data, re.I|re.M)
183-
if m:
184-
toext = ".cxx"
185-
fromfile = filename
186-
tofile = filename[:-len(fromext)] + toext
187-
process(cur_dir, fromfile, tofile, function, hash_db)
188-
save_hashes(hash_db, HASH_FILE)
201+
elif filename.endswith('.pxi.in'):
202+
toext = '.pxi'
203+
fromext = '.pxi.in'
204+
fromfile = filename
205+
function = process_tempita_pxi
206+
tofile = filename[:-len(fromext)] + toext
207+
process(cur_dir, fromfile, tofile, function, hash_db)
208+
save_hashes(hash_db, HASH_FILE)
209+
else:
210+
for fromext, function in rules.items():
211+
if filename.endswith(fromext):
212+
toext = ".c"
213+
with open(os.path.join(cur_dir, filename), 'rb') as f:
214+
data = f.read()
215+
m = re.search(br"^\s*#\s*distutils:\s*language\s*=\s*c\+\+\s*$", data, re.I|re.M)
216+
if m:
217+
toext = ".cxx"
218+
fromfile = filename
219+
tofile = filename[:-len(fromext)] + toext
220+
process(cur_dir, fromfile, tofile, function, hash_db)
221+
save_hashes(hash_db, HASH_FILE)
189222

190223
def main():
191224
try:

0 commit comments

Comments
 (0)