diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a445ed1..62a451a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -118,6 +118,7 @@ jobs: python -m pip install -U setuptools python -m pip install -U wheel python -m pip install "numpy==${{ matrix.numpy-version }}" + python -m pip install -U pytest python -m pip install -U mypy - name: Install diff --git a/quantities/registry.py b/quantities/registry.py index 4f9f3c7..7603965 100644 --- a/quantities/registry.py +++ b/quantities/registry.py @@ -1,41 +1,50 @@ """ """ +import ast import re -import builtins class UnitRegistry: + # Note that this structure ensures that UnitRegistry behaves as a singleton class __Registry: __shared_state = {} + whitelist = ( + ast.Expression, + ast.Constant, + ast.Name, + ast.Load, + ast.BinOp, + ast.UnaryOp, + ast.operator, + ast.unaryop, + ) def __init__(self): self.__dict__ = self.__shared_state self.__context = {} def __getitem__(self, string): - - # easy hack to prevent arbitrary evaluation of code - all_builtins = dir(builtins) - # because we have kilobytes, other bytes we have to remove bytes - all_builtins.remove("bytes") - # have to deal with octet as well - all_builtins.remove("oct") - # have to remove min which is short for minute - all_builtins.remove("min") - for builtin in all_builtins: - if builtin in string: - raise RuntimeError(f"String parsing error for `{string}`. Enter a string accepted by quantities") - - try: - return eval(string, self.__context) - except NameError: + # This approach to avoiding arbitrary evaluation of code is based on https://stackoverflow.com/a/11952618 + # by https://stackoverflow.com/users/567292/ecatmur + tree = ast.parse(string, mode="eval") + valid = all(isinstance(node, self.whitelist) for node in ast.walk(tree)) + if valid: + try: + item = eval( + compile(tree, filename="", mode="eval"), + {"__builtins__": {}}, + self.__context, + ) + except NameError: + raise LookupError('Unable to parse units: "%s"' % string) + else: + return item + else: # could return self['UnitQuantity'](string) - raise LookupError( - 'Unable to parse units: "%s"'%string - ) + raise LookupError('Unable to parse units: "%s"' % string) def __setitem__(self, string, val): assert isinstance(string, str) diff --git a/quantities/tests/test_units.py b/quantities/tests/test_units.py index 2bbbf93..1912a3a 100644 --- a/quantities/tests/test_units.py +++ b/quantities/tests/test_units.py @@ -1,6 +1,9 @@ +import pytest + from .. import units as pq from .common import TestCase + class TestUnits(TestCase): def test_compound_units(self): @@ -30,3 +33,8 @@ def test_units_copy(self): self.assertQuantityEqual(pq.m.copy(), pq.m) pc_per_cc = pq.CompoundUnit("pc/cm**3") self.assertQuantityEqual(pc_per_cc.copy(), pc_per_cc) + + def test_code_injection(self): + with pytest.raises(LookupError) as exc_info: + pq.CompoundUnit("exec(\"print('Hello there.')\\nprint('General Wasabi!')\")") + assert "Wasabi" in str(exc_info.value)