Skip to content

ENH: Adding type checks to set_printoptions() #7859

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions numpy/core/arrayprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
datetime_data)
from .fromnumeric import ravel
from .numeric import asarray
import operator

if sys.version_info[0] >= 3:
_MAXINT = sys.maxsize
Expand Down Expand Up @@ -155,19 +156,19 @@ def set_printoptions(precision=None, threshold=None, edgeitems=None,
global _formatter

if linewidth is not None:
_line_width = linewidth
_line_width = operator.index(linewidth)
if threshold is not None:
_summaryThreshold = threshold
_summaryThreshold = operator.index(threshold)
if edgeitems is not None:
_summaryEdgeItems = edgeitems
_summaryEdgeItems = operator.index(edgeitems)
if precision is not None:
_float_output_precision = precision
_float_output_precision = operator.index(precision)
if suppress is not None:
_float_output_suppress_small = not not suppress
if nanstr is not None:
_nan_str = nanstr
_nan_str = str(nanstr)
if infstr is not None:
_inf_str = infstr
_inf_str = str(infstr)
_formatter = formatter

def get_printoptions():
Expand Down
6 changes: 5 additions & 1 deletion numpy/core/tests/test_arrayprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from numpy.compat import sixu
from numpy.testing import (
TestCase, run_module_suite, assert_, assert_equal
TestCase, run_module_suite, assert_, assert_equal, assert_raises
)

class TestArrayRepr(object):
Expand Down Expand Up @@ -157,6 +157,10 @@ def test_formatter_reset(self):
np.set_printoptions(formatter={'float_kind':None})
assert_equal(repr(x), "array([ 0., 1., 2.])")

def test_type_check(self):
d = {}
assert_raises(TypeError, np.set_printoptions, d)

def test_unicode_object_array():
import sys
if sys.version_info[0] >= 3:
Expand Down