diff --git a/py/internal.go b/py/internal.go index bf654631..df0e285c 100644 --- a/py/internal.go +++ b/py/internal.go @@ -95,11 +95,12 @@ func Index(a Object) (Int, error) { if err != nil { return 0, err } + if res, ok := A.(Int); ok { return res, nil } - return 0, err + return 0, ExceptionNewf(TypeError, "__index__ returned non-int: (type %s)", A.Type().Name) } return 0, ExceptionNewf(TypeError, "unsupported operand type(s) for index: '%s'", a.Type().Name) diff --git a/py/range.go b/py/range.go index fe31497e..a54005c8 100644 --- a/py/range.go +++ b/py/range.go @@ -160,6 +160,16 @@ func computeRangeLength(start, stop, step Int) Int { return res } +func getIndexWithDefault(i Object, d Int) (Int, error) { + if i == None { + return d, nil + } else if res, err := Index(i); err != nil { + return 0, err + } else { + return res, nil + } +} + func computeNegativeIndex(index, length Int) Int { if index < 0 { index += length @@ -177,19 +187,19 @@ func computeBoundIndex(index, length Int) Int { } func computeRangeSlice(r *Range, s *Slice) (Object, error) { - start, err := Index(s.Start) + start, err := getIndexWithDefault(s.Start, 0) if err != nil { - start = 0 + return nil, err } - stop, err := Index(s.Stop) + stop, err := getIndexWithDefault(s.Stop, r.Length) if err != nil { - stop = r.Length + return nil, err } - - step, err := Index(s.Step) + step, err := getIndexWithDefault(s.Step, 1) if err != nil { - step = 1 + return nil, err } + if step == 0 { return nil, ExceptionNewf(ValueError, "slice step cannot be zero") } diff --git a/py/tests/list.py b/py/tests/list.py index 3e8468b0..c4ffcb00 100644 --- a/py/tests/list.py +++ b/py/tests/list.py @@ -111,4 +111,34 @@ assertRaises(TypeError, lambda: list.sort(s5, key=1)) assertRaises(TypeError, lambda: list.sort(1)) +class Index: + def __index__(self): + return 1 + +a = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +b = Index() +assert a[b] == 1 +assert a[b:10] == a[1:10] +assert a[10:b:-1] == a[10:1:-1] + +class NonIntegerIndex: + def __index__(self): + return 1.1 + +a = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +b = NonIntegerIndex() +try: + a[b] +except TypeError: + pass +else: + assert False, "TypeError not raised" + +try: + a[b:10] +except TypeError: + pass +else: + assert False, "TypeError not raised" + doc="finished" diff --git a/py/tests/range.py b/py/tests/range.py index 44eca5db..c027b4fe 100644 --- a/py/tests/range.py +++ b/py/tests/range.py @@ -116,4 +116,24 @@ def __index__(self): assert a[b:10] == a[1:10] assert a[10:b:-1] == a[10:1:-1] +class NonIntegerIndex: + def __index__(self): + return 1.1 + +a = range(10) +b = NonIntegerIndex() +try: + a[b] +except TypeError: + pass +else: + assert False, "TypeError not raised" + +try: + a[b:10] +except TypeError: + pass +else: + assert False, "TypeError not raised" + doc="finished" diff --git a/py/tests/string.py b/py/tests/string.py index 8eb3b691..43487328 100644 --- a/py/tests/string.py +++ b/py/tests/string.py @@ -886,4 +886,35 @@ def index(s, i): assert uni[7:7:2] == '' assert uni[7:7:3] == '' +class Index: + def __index__(self): + return 1 + +a = '012345678910' +b = Index() +assert a[b] == '1' +assert a[b:10] == a[1:10] +assert a[10:b:-1] == a[10:1:-1] + +class NonIntegerIndex: + def __index__(self): + return 1.1 + +a = '012345678910' +b = NonIntegerIndex() +try: + a[b] +except TypeError: + pass +else: + assert False, "TypeError not raised" + +try: + a[b:10] +except TypeError: + pass +else: + assert False, "TypeError not raised" + + doc="finished" diff --git a/py/tests/tuple.py b/py/tests/tuple.py index 609a687c..8bc6df20 100644 --- a/py/tests/tuple.py +++ b/py/tests/tuple.py @@ -20,4 +20,34 @@ assert a * 0 == () assert a * -1 == () +class Index: + def __index__(self): + return 1 + +a = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) +b = Index() +assert a[b] == 1 +assert a[b:10] == a[1:10] +assert a[10:b:-1] == a[10:1:-1] + +class NonIntegerIndex: + def __index__(self): + return 1.1 + +a = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) +b = NonIntegerIndex() +try: + a[b] +except TypeError: + pass +else: + assert False, "TypeError not raised" + +try: + a[b:10] +except TypeError: + pass +else: + assert False, "TypeError not raised" + doc="finished"