Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 5ff3b16

Browse files
Adds zip and dict builtins overloads to support easy literal dict ctor (#992)
* Adds zip and dict builtins overloads to support easy literal dict ctor Motivation: there's no easy way to create Numba LiteralStrKeyDict objects for const dicts with many elements. This adds a special overload for dict builtin that creates LiteralStrKeyDict from tuple of pairs ('col_name', col_data). * Replacing zip overload builtin with internal sdc_tuple_zip function Details: zip builtin is already overloaded in Numba and has priority over user defined overloads, hence in cases when we want zip two single elements tuples, e.g. zip(('A', ), (1, )) builtin function will match and type inference will unliteral all tuples, producing iter objects (that are always homogeneous in Numba). That is, literality of objects will be lost. Using sdc_zip_tuples explicitly avoid this problem. * Fixing issue with literal dict ctor with single element * Fixing refcnt issue and adding tests * Adding rewrite for dict(zip()) calls
1 parent d07a81a commit 5ff3b16

File tree

5 files changed

+220
-2
lines changed

5 files changed

+220
-2
lines changed

sdc/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070

7171
import sdc.rewrites.dataframe_constructor
7272
import sdc.rewrites.read_csv_consts
73+
import sdc.rewrites.dict_zip_tuples
7374
import sdc.rewrites.dataframe_getitem_attribute
7475
import sdc.datatypes.hpat_pandas_functions
7576
import sdc.datatypes.hpat_pandas_dataframe_functions

sdc/functions/tuple_utils.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,14 @@
2525
# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626
# *****************************************************************************
2727

28+
from textwrap import dedent
29+
2830
from numba import types
29-
from numba.extending import (intrinsic, )
31+
from numba.extending import intrinsic
3032
from numba.core.typing.templates import (signature, )
33+
from numba.typed.dictobject import build_map
34+
35+
from sdc.utilities.utils import sdc_overload
3136

3237

3338
@intrinsic
@@ -205,3 +210,84 @@ def codegen(context, builder, sig, args):
205210
return context.make_tuple(builder, ret_type, [first_tup, second_tup])
206211

207212
return ret_type(data_type), codegen
213+
214+
215+
def sdc_tuple_zip(x, y):
216+
pass
217+
218+
219+
@sdc_overload(sdc_tuple_zip)
220+
def sdc_tuple_zip_ovld(x, y):
221+
""" This function combines tuple of pairs from two input tuples x and y, preserving
222+
literality of elements in them. """
223+
224+
if not (isinstance(x, types.BaseAnonymousTuple) and isinstance(y, types.BaseAnonymousTuple)):
225+
return None
226+
227+
res_size = min(len(x), len(y))
228+
func_impl_name = 'sdc_tuple_zip_impl'
229+
tup_elements = ', '.join([f"(x[{i}], y[{i}])" for i in range(res_size)])
230+
func_text = dedent(f"""
231+
def {func_impl_name}(x, y):
232+
return ({tup_elements}{',' if res_size else ''})
233+
""")
234+
use_globals, use_locals = {}, {}
235+
exec(func_text, use_globals, use_locals)
236+
return use_locals[func_impl_name]
237+
238+
# FIXME_Numba#6533: alternatively we could have used sdc_tuple_map_elementwise
239+
# to avoid another use of exec, but due to @intrinsic-s not supporting
240+
# prefer_literal option below implementation looses literaly of args!
241+
# from sdc.functions.tuple_utils import sdc_tuple_map_elementwise
242+
# def sdc_tuple_zip_impl(x, y):
243+
# return sdc_tuple_map_elementwise(
244+
# lambda a, b: (a, b),
245+
# x,
246+
# y
247+
# )
248+
#
249+
# return sdc_tuple_zip_impl
250+
251+
252+
@intrinsic
253+
def literal_dict_ctor(typingctx, items):
254+
255+
tup_size = len(items)
256+
key_order = {p[0].literal_value: i for i, p in enumerate(items)}
257+
ret_type = types.LiteralStrKeyDict(dict(items), key_order)
258+
259+
def codegen(context, builder, sig, args):
260+
items_val = args[0]
261+
262+
# extract elements from the input tuple and repack into a list of variables required by build_map
263+
repacked_items = []
264+
for i in range(tup_size):
265+
elem = builder.extract_value(items_val, i)
266+
elem_first = builder.extract_value(elem, 0)
267+
elem_second = builder.extract_value(elem, 1)
268+
repacked_items.append((elem_first, elem_second))
269+
d = build_map(context, builder, ret_type, items, repacked_items)
270+
return d
271+
272+
return ret_type(items), codegen
273+
274+
275+
@sdc_overload(dict)
276+
def dict_from_tuples_ovld(x):
277+
278+
accepted_tuple_types = (types.Tuple, types.UniTuple)
279+
if not isinstance(x, types.BaseAnonymousTuple):
280+
return None
281+
282+
def check_tuple_element(ty):
283+
return (isinstance(ty, accepted_tuple_types)
284+
and len(ty) == 2
285+
and isinstance(ty[0], types.StringLiteral))
286+
287+
# below checks that elements are tuples with size 2 and first element is literal string
288+
if not (len(x) != 0 and all(map(check_tuple_element, x))):
289+
assert False, f"Creating LiteralStrKeyDict not supported from pairs of: {x}"
290+
291+
def dict_from_tuples_impl(x):
292+
return literal_dict_ctor(x)
293+
return dict_from_tuples_impl

sdc/rewrites/dict_zip_tuples.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2019-2021, Intel Corporation All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
#
10+
# Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
16+
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
17+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
18+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
19+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
20+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
21+
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
22+
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
23+
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
24+
# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
# *****************************************************************************
26+
27+
from numba.core.rewrites import register_rewrite, Rewrite
28+
from numba.core.ir_utils import guard, get_definition
29+
from numba import errors
30+
from numba.core import ir
31+
32+
from sdc.rewrites.ir_utils import find_operations, import_function
33+
from sdc.functions.tuple_utils import sdc_tuple_zip
34+
35+
36+
@register_rewrite('before-inference')
37+
class RewriteDictZip(Rewrite):
38+
"""
39+
Searches for calls like dict(zip(arg1, arg2)) and replaces zip with sdc_zip.
40+
"""
41+
42+
def match(self, func_ir, block, typemap, calltypes):
43+
44+
self._block = block
45+
self._func_ir = func_ir
46+
self._calls_to_rewrite = set()
47+
48+
# Find all assignments with a RHS expr being a call to dict, and where arg
49+
# is a call to zip and store these ir.Expr for further modification
50+
for inst in find_operations(block=block, op_name='call'):
51+
expr = inst.value
52+
try:
53+
callee = func_ir.infer_constant(expr.func)
54+
except errors.ConstantInferenceError:
55+
continue
56+
57+
if (callee is dict and len(expr.args) == 1):
58+
dict_arg_expr = guard(get_definition, func_ir, expr.args[0])
59+
if (getattr(dict_arg_expr, 'op', None) == 'call'):
60+
called_func = guard(get_definition, func_ir, dict_arg_expr.func)
61+
if (called_func.value is zip and len(dict_arg_expr.args) == 2):
62+
self._calls_to_rewrite.add(dict_arg_expr)
63+
64+
return len(self._calls_to_rewrite) > 0
65+
66+
def apply(self):
67+
"""
68+
Replace call to zip in matched expressions with call to sdc_zip.
69+
"""
70+
new_block = self._block.copy()
71+
new_block.clear()
72+
zip_spec_stmt = import_function(sdc_tuple_zip, new_block, self._func_ir)
73+
for inst in self._block.body:
74+
if isinstance(inst, ir.Assign) and inst.value in self._calls_to_rewrite:
75+
expr = inst.value
76+
expr.func = zip_spec_stmt.target # injects the new function
77+
new_block.append(inst)
78+
return new_block

sdc/tests/test_basic.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@
3030
import pandas as pd
3131
import random
3232
import unittest
33+
from itertools import product
34+
3335
from numba import types
36+
from numba.tests.support import MemoryLeakMixin
3437

3538
import sdc
3639
from sdc.tests.test_base import TestCase
@@ -43,7 +46,8 @@
4346
dist_IR_contains,
4447
get_rank,
4548
get_start_end,
46-
skip_numba_jit)
49+
skip_numba_jit,
50+
assert_nbtype_for_varname)
4751

4852

4953
def get_np_state_ptr():
@@ -540,5 +544,47 @@ def test_rhs(arr_len):
540544
np.testing.assert_allclose(A, B)
541545

542546

547+
class TestPython(MemoryLeakMixin, TestCase):
548+
549+
def test_literal_dict_ctor(self):
550+
""" Verifies that dict builtin creates LiteralStrKeyDict from tuple
551+
of pairs ('col_name_i', col_data_i), where col_name_i is literal string """
552+
553+
def test_impl_1():
554+
items = (('A', np.arange(11)), )
555+
res = dict(items)
556+
return len(res)
557+
558+
def test_impl_2():
559+
items = (('A', np.arange(5)), ('B', np.ones(11)), )
560+
res = dict(items)
561+
return len(res)
562+
563+
local_vars = locals()
564+
list_tested_fns = [local_vars[k] for k in local_vars.keys() if k.startswith('test_impl')]
565+
566+
for test_impl in list_tested_fns:
567+
with self.subTest(tested_func_name=test_impl.__name__):
568+
sdc_func = self.jit(test_impl)
569+
self.assertEqual(sdc_func(), test_impl())
570+
assert_nbtype_for_varname(self, sdc_func, 'res', types.LiteralStrKeyDict)
571+
572+
def test_dict_zip_rewrite(self):
573+
""" Verifies that a compination of dict(zip()) creates LiteralStrKeyDict when
574+
zip is applied to tuples of literal column names and columns data """
575+
576+
dict_keys = ('A', 'B')
577+
dict_values = (np.ones(5), np.array([1, 2, 3]))
578+
579+
def test_impl():
580+
res = dict(zip(dict_keys, dict_values))
581+
return len(res)
582+
583+
sdc_func = self.jit(test_impl)
584+
expected = len(dict(zip(dict_keys, dict_values)))
585+
self.assertEqual(sdc_func(), expected)
586+
assert_nbtype_for_varname(self, sdc_func, 'res', types.LiteralStrKeyDict)
587+
588+
543589
if __name__ == "__main__":
544590
unittest.main()

sdc/tests/test_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,10 @@ def _make_func_from_text(func_text, func_name='test_impl', global_vars={}):
272272
exec(func_text, global_vars, loc_vars)
273273
test_impl = loc_vars[func_name]
274274
return test_impl
275+
276+
277+
def assert_nbtype_for_varname(self, disp, var, expected_type, fn_sig=None):
278+
fn_sig = fn_sig or disp.nopython_signatures[0]
279+
cres = disp.get_compile_result(fn_sig)
280+
fn_typemap = cres.type_annotation.typemap
281+
self.assertIsInstance(fn_typemap[var], expected_type)

0 commit comments

Comments
 (0)