diff --git a/Lib/symtable.py b/Lib/symtable.py new file mode 100644 index 0000000000..672ec0ce1f --- /dev/null +++ b/Lib/symtable.py @@ -0,0 +1,414 @@ +"""Interface to the compiler's internal symbol tables""" + +import _symtable +from _symtable import (USE, DEF_GLOBAL, DEF_NONLOCAL, DEF_LOCAL, DEF_PARAM, + DEF_IMPORT, DEF_BOUND, DEF_ANNOT, SCOPE_OFF, SCOPE_MASK, FREE, + LOCAL, GLOBAL_IMPLICIT, GLOBAL_EXPLICIT, CELL) + +import weakref +from enum import StrEnum + +__all__ = ["symtable", "SymbolTableType", "SymbolTable", "Class", "Function", "Symbol"] + +def symtable(code, filename, compile_type): + """ Return the toplevel *SymbolTable* for the source code. + + *filename* is the name of the file with the code + and *compile_type* is the *compile()* mode argument. + """ + top = _symtable.symtable(code, filename, compile_type) + return _newSymbolTable(top, filename) + +class SymbolTableFactory: + def __init__(self): + self.__memo = weakref.WeakValueDictionary() + + def new(self, table, filename): + if table.type == _symtable.TYPE_FUNCTION: + return Function(table, filename) + if table.type == _symtable.TYPE_CLASS: + return Class(table, filename) + return SymbolTable(table, filename) + + def __call__(self, table, filename): + key = table, filename + obj = self.__memo.get(key, None) + if obj is None: + obj = self.__memo[key] = self.new(table, filename) + return obj + +_newSymbolTable = SymbolTableFactory() + + +class SymbolTableType(StrEnum): + MODULE = "module" + FUNCTION = "function" + CLASS = "class" + ANNOTATION = "annotation" + TYPE_ALIAS = "type alias" + TYPE_PARAMETERS = "type parameters" + TYPE_VARIABLE = "type variable" + + +class SymbolTable: + + def __init__(self, raw_table, filename): + self._table = raw_table + self._filename = filename + self._symbols = {} + + def __repr__(self): + if self.__class__ == SymbolTable: + kind = "" + else: + kind = "%s " % self.__class__.__name__ + + if self._table.name == "top": + return "<{0}SymbolTable for module {1}>".format(kind, self._filename) + else: + return "<{0}SymbolTable for {1} in {2}>".format(kind, + self._table.name, + self._filename) + + def get_type(self): + """Return the type of the symbol table. + + The value returned is one of the values in + the ``SymbolTableType`` enumeration. + """ + if self._table.type == _symtable.TYPE_MODULE: + return SymbolTableType.MODULE + if self._table.type == _symtable.TYPE_FUNCTION: + return SymbolTableType.FUNCTION + if self._table.type == _symtable.TYPE_CLASS: + return SymbolTableType.CLASS + if self._table.type == _symtable.TYPE_ANNOTATION: + return SymbolTableType.ANNOTATION + if self._table.type == _symtable.TYPE_TYPE_ALIAS: + return SymbolTableType.TYPE_ALIAS + if self._table.type == _symtable.TYPE_TYPE_PARAMETERS: + return SymbolTableType.TYPE_PARAMETERS + if self._table.type == _symtable.TYPE_TYPE_VARIABLE: + return SymbolTableType.TYPE_VARIABLE + assert False, f"unexpected type: {self._table.type}" + + def get_id(self): + """Return an identifier for the table. + """ + return self._table.id + + def get_name(self): + """Return the table's name. + + This corresponds to the name of the class, function + or 'top' if the table is for a class, function or + global respectively. + """ + return self._table.name + + def get_lineno(self): + """Return the number of the first line in the + block for the table. + """ + return self._table.lineno + + def is_optimized(self): + """Return *True* if the locals in the table + are optimizable. + """ + return bool(self._table.type == _symtable.TYPE_FUNCTION) + + def is_nested(self): + """Return *True* if the block is a nested class + or function.""" + return bool(self._table.nested) + + def has_children(self): + """Return *True* if the block has nested namespaces. + """ + return bool(self._table.children) + + def get_identifiers(self): + """Return a view object containing the names of symbols in the table. + """ + return self._table.symbols.keys() + + def lookup(self, name): + """Lookup a *name* in the table. + + Returns a *Symbol* instance. + """ + sym = self._symbols.get(name) + if sym is None: + flags = self._table.symbols[name] + namespaces = self.__check_children(name) + module_scope = (self._table.name == "top") + sym = self._symbols[name] = Symbol(name, flags, namespaces, + module_scope=module_scope) + return sym + + def get_symbols(self): + """Return a list of *Symbol* instances for + names in the table. + """ + return [self.lookup(ident) for ident in self.get_identifiers()] + + def __check_children(self, name): + return [_newSymbolTable(st, self._filename) + for st in self._table.children + if st.name == name] + + def get_children(self): + """Return a list of the nested symbol tables. + """ + return [_newSymbolTable(st, self._filename) + for st in self._table.children] + + +class Function(SymbolTable): + + # Default values for instance variables + __params = None + __locals = None + __frees = None + __globals = None + __nonlocals = None + + def __idents_matching(self, test_func): + return tuple(ident for ident in self.get_identifiers() + if test_func(self._table.symbols[ident])) + + def get_parameters(self): + """Return a tuple of parameters to the function. + """ + if self.__params is None: + self.__params = self.__idents_matching(lambda x:x & DEF_PARAM) + return self.__params + + def get_locals(self): + """Return a tuple of locals in the function. + """ + if self.__locals is None: + locs = (LOCAL, CELL) + test = lambda x: ((x >> SCOPE_OFF) & SCOPE_MASK) in locs + self.__locals = self.__idents_matching(test) + return self.__locals + + def get_globals(self): + """Return a tuple of globals in the function. + """ + if self.__globals is None: + glob = (GLOBAL_IMPLICIT, GLOBAL_EXPLICIT) + test = lambda x:((x >> SCOPE_OFF) & SCOPE_MASK) in glob + self.__globals = self.__idents_matching(test) + return self.__globals + + def get_nonlocals(self): + """Return a tuple of nonlocals in the function. + """ + if self.__nonlocals is None: + self.__nonlocals = self.__idents_matching(lambda x:x & DEF_NONLOCAL) + return self.__nonlocals + + def get_frees(self): + """Return a tuple of free variables in the function. + """ + if self.__frees is None: + is_free = lambda x:((x >> SCOPE_OFF) & SCOPE_MASK) == FREE + self.__frees = self.__idents_matching(is_free) + return self.__frees + + +class Class(SymbolTable): + + __methods = None + + def get_methods(self): + """Return a tuple of methods declared in the class. + """ + if self.__methods is None: + d = {} + + def is_local_symbol(ident): + flags = self._table.symbols.get(ident, 0) + return ((flags >> SCOPE_OFF) & SCOPE_MASK) == LOCAL + + for st in self._table.children: + # pick the function-like symbols that are local identifiers + if is_local_symbol(st.name): + match st.type: + case _symtable.TYPE_FUNCTION: + # generators are of type TYPE_FUNCTION with a ".0" + # parameter as a first parameter (which makes them + # distinguishable from a function named 'genexpr') + if st.name == 'genexpr' and '.0' in st.varnames: + continue + d[st.name] = 1 + case _symtable.TYPE_TYPE_PARAMETERS: + # Get the function-def block in the annotation + # scope 'st' with the same identifier, if any. + scope_name = st.name + for c in st.children: + if c.name == scope_name and c.type == _symtable.TYPE_FUNCTION: + # A generic generator of type TYPE_FUNCTION + # cannot be a direct child of 'st' (but it + # can be a descendant), e.g.: + # + # class A: + # type genexpr[genexpr] = (x for x in []) + assert scope_name != 'genexpr' or '.0' not in c.varnames + d[scope_name] = 1 + break + self.__methods = tuple(d) + return self.__methods + + +class Symbol: + + def __init__(self, name, flags, namespaces=None, *, module_scope=False): + self.__name = name + self.__flags = flags + self.__scope = (flags >> SCOPE_OFF) & SCOPE_MASK # like PyST_GetScope() + self.__namespaces = namespaces or () + self.__module_scope = module_scope + + def __repr__(self): + flags_str = '|'.join(self._flags_str()) + return f'' + + def _scope_str(self): + return _scopes_value_to_name.get(self.__scope) or str(self.__scope) + + def _flags_str(self): + for flagname, flagvalue in _flags: + if self.__flags & flagvalue == flagvalue: + yield flagname + + def get_name(self): + """Return a name of a symbol. + """ + return self.__name + + def is_referenced(self): + """Return *True* if the symbol is used in + its block. + """ + return bool(self.__flags & _symtable.USE) + + def is_parameter(self): + """Return *True* if the symbol is a parameter. + """ + return bool(self.__flags & DEF_PARAM) + + def is_global(self): + """Return *True* if the symbol is global. + """ + return bool(self.__scope in (GLOBAL_IMPLICIT, GLOBAL_EXPLICIT) + or (self.__module_scope and self.__flags & DEF_BOUND)) + + def is_nonlocal(self): + """Return *True* if the symbol is nonlocal.""" + return bool(self.__flags & DEF_NONLOCAL) + + def is_declared_global(self): + """Return *True* if the symbol is declared global + with a global statement.""" + return bool(self.__scope == GLOBAL_EXPLICIT) + + def is_local(self): + """Return *True* if the symbol is local. + """ + return bool(self.__scope in (LOCAL, CELL) + or (self.__module_scope and self.__flags & DEF_BOUND)) + + def is_annotated(self): + """Return *True* if the symbol is annotated. + """ + return bool(self.__flags & DEF_ANNOT) + + def is_free(self): + """Return *True* if a referenced symbol is + not assigned to. + """ + return bool(self.__scope == FREE) + + def is_imported(self): + """Return *True* if the symbol is created from + an import statement. + """ + return bool(self.__flags & DEF_IMPORT) + + def is_assigned(self): + """Return *True* if a symbol is assigned to.""" + return bool(self.__flags & DEF_LOCAL) + + def is_namespace(self): + """Returns *True* if name binding introduces new namespace. + + If the name is used as the target of a function or class + statement, this will be true. + + Note that a single name can be bound to multiple objects. If + is_namespace() is true, the name may also be bound to other + objects, like an int or list, that does not introduce a new + namespace. + """ + return bool(self.__namespaces) + + def get_namespaces(self): + """Return a list of namespaces bound to this name""" + return self.__namespaces + + def get_namespace(self): + """Return the single namespace bound to this name. + + Raises ValueError if the name is bound to multiple namespaces + or no namespace. + """ + if len(self.__namespaces) == 0: + raise ValueError("name is not bound to any namespaces") + elif len(self.__namespaces) > 1: + raise ValueError("name is bound to multiple namespaces") + else: + return self.__namespaces[0] + + +_flags = [('USE', USE)] +_flags.extend(kv for kv in globals().items() if kv[0].startswith('DEF_')) +_scopes_names = ('FREE', 'LOCAL', 'GLOBAL_IMPLICIT', 'GLOBAL_EXPLICIT', 'CELL') +_scopes_value_to_name = {globals()[n]: n for n in _scopes_names} + + +def main(args): + import sys + def print_symbols(table, level=0): + indent = ' ' * level + nested = "nested " if table.is_nested() else "" + if table.get_type() == 'module': + what = f'from file {table._filename!r}' + else: + what = f'{table.get_name()!r}' + print(f'{indent}symbol table for {nested}{table.get_type()} {what}:') + for ident in table.get_identifiers(): + symbol = table.lookup(ident) + flags = ', '.join(symbol._flags_str()).lower() + print(f' {indent}{symbol._scope_str().lower()} symbol {symbol.get_name()!r}: {flags}') + print() + + for table2 in table.get_children(): + print_symbols(table2, level + 1) + + for filename in args or ['-']: + if filename == '-': + src = sys.stdin.read() + filename = '' + else: + with open(filename, 'rb') as f: + src = f.read() + mod = symtable(src, filename, 'exec') + print_symbols(mod) + + +if __name__ == "__main__": + import sys + main(sys.argv[1:]) diff --git a/Lib/test/test_symtable.py b/Lib/test/test_symtable.py index ca62b9d685..8f1a80a524 100644 --- a/Lib/test/test_symtable.py +++ b/Lib/test/test_symtable.py @@ -1,9 +1,13 @@ """ Test the API of the symtable module. """ + +import textwrap import symtable import unittest +from test import support +from test.support import os_helper TEST_CODE = """ @@ -11,7 +15,7 @@ glob = 42 some_var = 12 -some_non_assigned_global_var = 11 +some_non_assigned_global_var: int some_assigned_global_var = 11 class Mine: @@ -40,6 +44,129 @@ def foo(): def namespace_test(): pass def namespace_test(): pass + +type Alias = int +type GenericAlias[T] = list[T] + +def generic_spam[T](a): + pass + +class GenericMine[T: int, U: (int, str) = int]: + pass +""" + +TEST_COMPLEX_CLASS_CODE = """ +# The following symbols are defined in ComplexClass +# without being introduced by a 'global' statement. +glob_unassigned_meth: Any +glob_unassigned_meth_pep_695: Any + +glob_unassigned_async_meth: Any +glob_unassigned_async_meth_pep_695: Any + +def glob_assigned_meth(): pass +def glob_assigned_meth_pep_695[T](): pass + +async def glob_assigned_async_meth(): pass +async def glob_assigned_async_meth_pep_695[T](): pass + +# The following symbols are defined in ComplexClass after +# being introduced by a 'global' statement (and therefore +# are not considered as local symbols of ComplexClass). +glob_unassigned_meth_ignore: Any +glob_unassigned_meth_pep_695_ignore: Any + +glob_unassigned_async_meth_ignore: Any +glob_unassigned_async_meth_pep_695_ignore: Any + +def glob_assigned_meth_ignore(): pass +def glob_assigned_meth_pep_695_ignore[T](): pass + +async def glob_assigned_async_meth_ignore(): pass +async def glob_assigned_async_meth_pep_695_ignore[T](): pass + +class ComplexClass: + a_var = 1234 + a_genexpr = (x for x in []) + a_lambda = lambda x: x + + type a_type_alias = int + type a_type_alias_pep_695[T] = list[T] + + class a_class: pass + class a_class_pep_695[T]: pass + + def a_method(self): pass + def a_method_pep_695[T](self): pass + + async def an_async_method(self): pass + async def an_async_method_pep_695[T](self): pass + + @classmethod + def a_classmethod(cls): pass + @classmethod + def a_classmethod_pep_695[T](self): pass + + @classmethod + async def an_async_classmethod(cls): pass + @classmethod + async def an_async_classmethod_pep_695[T](self): pass + + @staticmethod + def a_staticmethod(): pass + @staticmethod + def a_staticmethod_pep_695[T](self): pass + + @staticmethod + async def an_async_staticmethod(): pass + @staticmethod + async def an_async_staticmethod_pep_695[T](self): pass + + # These ones will be considered as methods because of the 'def' although + # they are *not* valid methods at runtime since they are not decorated + # with @staticmethod. + def a_fakemethod(): pass + def a_fakemethod_pep_695[T](): pass + + async def an_async_fakemethod(): pass + async def an_async_fakemethod_pep_695[T](): pass + + # Check that those are still considered as methods + # since they are not using the 'global' keyword. + def glob_unassigned_meth(): pass + def glob_unassigned_meth_pep_695[T](): pass + + async def glob_unassigned_async_meth(): pass + async def glob_unassigned_async_meth_pep_695[T](): pass + + def glob_assigned_meth(): pass + def glob_assigned_meth_pep_695[T](): pass + + async def glob_assigned_async_meth(): pass + async def glob_assigned_async_meth_pep_695[T](): pass + + # The following are not picked as local symbols because they are not + # visible by the class at runtime (this is equivalent to having the + # definitions outside of the class). + global glob_unassigned_meth_ignore + def glob_unassigned_meth_ignore(): pass + global glob_unassigned_meth_pep_695_ignore + def glob_unassigned_meth_pep_695_ignore[T](): pass + + global glob_unassigned_async_meth_ignore + async def glob_unassigned_async_meth_ignore(): pass + global glob_unassigned_async_meth_pep_695_ignore + async def glob_unassigned_async_meth_pep_695_ignore[T](): pass + + global glob_assigned_meth_ignore + def glob_assigned_meth_ignore(): pass + global glob_assigned_meth_pep_695_ignore + def glob_assigned_meth_pep_695_ignore[T](): pass + + global glob_assigned_async_meth_ignore + async def glob_assigned_async_meth_ignore(): pass + global glob_assigned_async_meth_pep_695_ignore + async def glob_assigned_async_meth_pep_695_ignore[T](): pass """ @@ -54,18 +181,46 @@ class SymtableTest(unittest.TestCase): top = symtable.symtable(TEST_CODE, "?", "exec") # These correspond to scopes in TEST_CODE Mine = find_block(top, "Mine") + a_method = find_block(Mine, "a_method") spam = find_block(top, "spam") internal = find_block(spam, "internal") other_internal = find_block(spam, "other_internal") foo = find_block(top, "foo") + Alias = find_block(top, "Alias") + GenericAlias = find_block(top, "GenericAlias") + # XXX: RUSTPYTHON + # GenericAlias_inner = find_block(GenericAlias, "GenericAlias") + generic_spam = find_block(top, "generic_spam") + # XXX: RUSTPYTHON + # generic_spam_inner = find_block(generic_spam, "generic_spam") + GenericMine = find_block(top, "GenericMine") + # XXX: RUSTPYTHON + # GenericMine_inner = find_block(GenericMine, "GenericMine") + # XXX: RUSTPYTHON + # T = find_block(GenericMine, "T") + # XXX: RUSTPYTHON + # U = find_block(GenericMine, "U") + + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_type(self): self.assertEqual(self.top.get_type(), "module") self.assertEqual(self.Mine.get_type(), "class") self.assertEqual(self.a_method.get_type(), "function") self.assertEqual(self.spam.get_type(), "function") self.assertEqual(self.internal.get_type(), "function") + self.assertEqual(self.foo.get_type(), "function") + self.assertEqual(self.Alias.get_type(), "type alias") + self.assertEqual(self.GenericAlias.get_type(), "type parameters") + self.assertEqual(self.GenericAlias_inner.get_type(), "type alias") + self.assertEqual(self.generic_spam.get_type(), "type parameters") + self.assertEqual(self.generic_spam_inner.get_type(), "function") + self.assertEqual(self.GenericMine.get_type(), "type parameters") + self.assertEqual(self.GenericMine_inner.get_type(), "class") + self.assertEqual(self.T.get_type(), "type variable") + self.assertEqual(self.U.get_type(), "type variable") # TODO: RUSTPYTHON @unittest.expectedFailure @@ -75,6 +230,11 @@ def test_id(self): self.assertGreater(self.a_method.get_id(), 0) self.assertGreater(self.spam.get_id(), 0) self.assertGreater(self.internal.get_id(), 0) + self.assertGreater(self.foo.get_id(), 0) + self.assertGreater(self.Alias.get_id(), 0) + self.assertGreater(self.GenericAlias.get_id(), 0) + self.assertGreater(self.generic_spam.get_id(), 0) + self.assertGreater(self.GenericMine.get_id(), 0) def test_optimized(self): self.assertFalse(self.top.is_optimized()) @@ -106,6 +266,8 @@ def test_function_info(self): self.assertEqual(sorted(func.get_globals()), ["bar", "glob", "some_assigned_global_var"]) self.assertEqual(self.internal.get_frees(), ("x",)) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_globals(self): self.assertTrue(self.spam.lookup("glob").is_global()) self.assertFalse(self.spam.lookup("glob").is_declared_global()) @@ -126,6 +288,8 @@ def test_nonlocal(self): expected = ("some_var",) self.assertEqual(self.other_internal.get_nonlocals(), expected) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_local(self): self.assertTrue(self.spam.lookup("x").is_local()) self.assertFalse(self.spam.lookup("bar").is_local()) @@ -133,9 +297,13 @@ def test_local(self): self.assertTrue(self.top.lookup("some_non_assigned_global_var").is_local()) self.assertTrue(self.top.lookup("some_assigned_global_var").is_local()) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_free(self): self.assertTrue(self.internal.lookup("x").is_free()) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_referenced(self): self.assertTrue(self.internal.lookup("x").is_referenced()) self.assertTrue(self.spam.lookup("internal").is_referenced()) @@ -167,6 +335,10 @@ def test_namespaces(self): self.assertEqual(len(ns_test.get_namespaces()), 2) self.assertRaises(ValueError, ns_test.get_namespace) + ns_test_2 = self.top.lookup("glob") + self.assertEqual(len(ns_test_2.get_namespaces()), 0) + self.assertRaises(ValueError, ns_test_2.get_namespace) + def test_assigned(self): self.assertTrue(self.spam.lookup("x").is_assigned()) self.assertTrue(self.spam.lookup("bar").is_assigned()) @@ -174,6 +346,8 @@ def test_assigned(self): self.assertTrue(self.Mine.lookup("a_method").is_assigned()) self.assertFalse(self.internal.lookup("x").is_assigned()) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_annotated(self): st1 = symtable.symtable('def f():\n x: int\n', 'test', 'exec') st2 = st1.get_children()[0] @@ -199,6 +373,8 @@ def test_annotated(self): ' x: int', 'test', 'exec') + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_imported(self): self.assertTrue(self.top.lookup("sys").is_imported()) @@ -210,9 +386,79 @@ def test_name(self): # TODO: RUSTPYTHON @unittest.expectedFailure - def test_class_info(self): + def test_class_get_methods(self): self.assertEqual(self.Mine.get_methods(), ('a_method',)) + top = symtable.symtable(TEST_COMPLEX_CLASS_CODE, "?", "exec") + this = find_block(top, "ComplexClass") + + self.assertEqual(this.get_methods(), ( + 'a_method', 'a_method_pep_695', + 'an_async_method', 'an_async_method_pep_695', + 'a_classmethod', 'a_classmethod_pep_695', + 'an_async_classmethod', 'an_async_classmethod_pep_695', + 'a_staticmethod', 'a_staticmethod_pep_695', + 'an_async_staticmethod', 'an_async_staticmethod_pep_695', + 'a_fakemethod', 'a_fakemethod_pep_695', + 'an_async_fakemethod', 'an_async_fakemethod_pep_695', + 'glob_unassigned_meth', 'glob_unassigned_meth_pep_695', + 'glob_unassigned_async_meth', 'glob_unassigned_async_meth_pep_695', + 'glob_assigned_meth', 'glob_assigned_meth_pep_695', + 'glob_assigned_async_meth', 'glob_assigned_async_meth_pep_695', + )) + + # Test generator expressions that are of type TYPE_FUNCTION + # but will not be reported by get_methods() since they are + # not functions per se. + # + # Other kind of comprehensions such as list, set or dict + # expressions do not have the TYPE_FUNCTION type. + + def check_body(body, expected_methods): + indented = textwrap.indent(body, ' ' * 4) + top = symtable.symtable(f"class A:\n{indented}", "?", "exec") + this = find_block(top, "A") + self.assertEqual(this.get_methods(), expected_methods) + + # statements with 'genexpr' inside it + GENEXPRS = ( + 'x = (x for x in [])', + 'x = (x async for x in [])', + 'type x[genexpr = (x for x in [])] = (x for x in [])', + 'type x[genexpr = (x async for x in [])] = (x async for x in [])', + 'genexpr = (x for x in [])', + 'genexpr = (x async for x in [])', + 'type genexpr[genexpr = (x for x in [])] = (x for x in [])', + 'type genexpr[genexpr = (x async for x in [])] = (x async for x in [])', + ) + + for gen in GENEXPRS: + # test generator expression + with self.subTest(gen=gen): + check_body(gen, ()) + + # test generator expression + variable named 'genexpr' + with self.subTest(gen=gen, isvar=True): + check_body('\n'.join((gen, 'genexpr = 1')), ()) + check_body('\n'.join(('genexpr = 1', gen)), ()) + + for paramlist in ('()', '(x)', '(x, y)', '(z: T)'): + for func in ( + f'def genexpr{paramlist}:pass', + f'async def genexpr{paramlist}:pass', + f'def genexpr[T]{paramlist}:pass', + f'async def genexpr[T]{paramlist}:pass', + ): + with self.subTest(func=func): + # test function named 'genexpr' + check_body(func, ('genexpr',)) + + for gen in GENEXPRS: + with self.subTest(gen=gen, func=func): + # test generator expression + function named 'genexpr' + check_body('\n'.join((gen, func)), ('genexpr',)) + check_body('\n'.join((func, gen)), ('genexpr',)) + # TODO: RUSTPYTHON @unittest.expectedFailure def test_filename_correct(self): @@ -230,10 +476,9 @@ def checkfilename(brokencode, offset): checkfilename("def f(x): foo)(", 14) # parse-time checkfilename("def f(x): global x", 11) # symtable-build-time symtable.symtable("pass", b"spam", "exec") - with self.assertWarns(DeprecationWarning), \ - self.assertRaises(TypeError): + with self.assertRaises(TypeError): symtable.symtable("pass", bytearray(b"spam"), "exec") - with self.assertWarns(DeprecationWarning): + with self.assertRaises(TypeError): symtable.symtable("pass", memoryview(b"spam"), "exec") with self.assertRaises(TypeError): symtable.symtable("pass", list(b"spam"), "exec") @@ -258,12 +503,71 @@ def test_bytes(self): top = symtable.symtable(code, "?", "exec") self.assertIsNotNone(find_block(top, "\u017d")) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_symtable_repr(self): self.assertEqual(str(self.top), "") self.assertEqual(str(self.spam), "") + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_symbol_repr(self): + self.assertEqual(repr(self.spam.lookup("glob")), + "") + self.assertEqual(repr(self.spam.lookup("bar")), + "") + self.assertEqual(repr(self.spam.lookup("a")), + "") + self.assertEqual(repr(self.spam.lookup("internal")), + "") + self.assertEqual(repr(self.spam.lookup("other_internal")), + "") + self.assertEqual(repr(self.internal.lookup("x")), + "") + self.assertEqual(repr(self.other_internal.lookup("some_var")), + "") + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_symtable_entry_repr(self): + expected = f"" + self.assertEqual(repr(self.top._table), expected) + + +class CommandLineTest(unittest.TestCase): + maxDiff = None + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_file(self): + filename = os_helper.TESTFN + self.addCleanup(os_helper.unlink, filename) + with open(filename, 'w') as f: + f.write(TEST_CODE) + with support.captured_stdout() as stdout: + symtable.main([filename]) + out = stdout.getvalue() + self.assertIn('\n\n', out) + self.assertNotIn('\n\n\n', out) + lines = out.splitlines() + self.assertIn(f"symbol table for module from file {filename!r}:", lines) + self.assertIn(" local symbol 'glob': def_local", lines) + self.assertIn(" global_implicit symbol 'glob': use", lines) + self.assertIn(" local symbol 'spam': def_local", lines) + self.assertIn(" symbol table for function 'spam':", lines) + + def test_stdin(self): + with support.captured_stdin() as stdin: + stdin.write(TEST_CODE) + stdin.seek(0) + with support.captured_stdout() as stdout: + symtable.main([]) + out = stdout.getvalue() + stdin.seek(0) + with support.captured_stdout() as stdout: + symtable.main(['-']) + self.assertEqual(stdout.getvalue(), out) + lines = out.splitlines() + self.assertIn("symbol table for module from file '':", lines) + if __name__ == '__main__': unittest.main() diff --git a/vm/src/stdlib/mod.rs b/vm/src/stdlib/mod.rs index c2f17fd00b..60cb09f405 100644 --- a/vm/src/stdlib/mod.rs +++ b/vm/src/stdlib/mod.rs @@ -108,7 +108,7 @@ pub fn get_module_inits() -> StdlibMap { // compiler related modules: #[cfg(feature = "compiler")] { - "symtable" => symtable::make_module, + "_symtable" => symtable::make_module, } #[cfg(any(unix, target_os = "wasi"))] { diff --git a/vm/src/stdlib/symtable.rs b/vm/src/stdlib/symtable.rs index 0ee137642c..8a14285778 100644 --- a/vm/src/stdlib/symtable.rs +++ b/vm/src/stdlib/symtable.rs @@ -3,13 +3,111 @@ pub(crate) use symtable::make_module; #[pymodule] mod symtable { use crate::{ - PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::PyStrRef, compiler, + PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, + builtins::{PyDictRef, PyStrRef}, + compiler, }; use rustpython_codegen::symboltable::{ CompilerScope, Symbol, SymbolFlags, SymbolScope, SymbolTable, }; use std::fmt; + // Consts as defined at + // https://github.com/python/cpython/blob/6cb20a219a860eaf687b2d968b41c480c7461909/Include/internal/pycore_symtable.h#L156 + + #[pyattr] + pub const DEF_GLOBAL: i32 = 1; + + #[pyattr] + pub const DEF_LOCAL: i32 = 2; + + #[pyattr] + pub const DEF_PARAM: i32 = 2 << 1; + + #[pyattr] + pub const DEF_NONLOCAL: i32 = 2 << 2; + + #[pyattr] + pub const USE: i32 = 2 << 3; + + #[pyattr] + pub const DEF_FREE: i32 = 2 << 4; + + #[pyattr] + pub const DEF_FREE_CLASS: i32 = 2 << 5; + + #[pyattr] + pub const DEF_IMPORT: i32 = 2 << 6; + + #[pyattr] + pub const DEF_ANNOT: i32 = 2 << 7; + + #[pyattr] + pub const DEF_COMP_ITER: i32 = 2 << 8; + + #[pyattr] + pub const DEF_TYPE_PARAM: i32 = 2 << 9; + + #[pyattr] + pub const DEF_COMP_CELL: i32 = 2 << 10; + + #[pyattr] + pub const DEF_BOUND: i32 = DEF_LOCAL | DEF_PARAM | DEF_IMPORT; + + #[pyattr] + pub const SCOPE_OFFSET: i32 = 12; + + #[pyattr] + pub const SCOPE_MASK: i32 = DEF_GLOBAL | DEF_LOCAL | DEF_PARAM | DEF_NONLOCAL; + + #[pyattr] + pub const LOCAL: i32 = 1; + + #[pyattr] + pub const GLOBAL_EXPLICIT: i32 = 2; + + #[pyattr] + pub const GLOBAL_IMPLICIT: i32 = 3; + + #[pyattr] + pub const FREE: i32 = 4; + + #[pyattr] + pub const CELL: i32 = 5; + + #[pyattr] + pub const GENERATOR: i32 = 1; + + #[pyattr] + pub const GENERATOR_EXPRESSION: i32 = 2; + + #[pyattr] + pub const SCOPE_OFF: i32 = SCOPE_OFFSET; + + #[pyattr] + pub const TYPE_FUNCTION: i32 = 0; + + #[pyattr] + pub const TYPE_CLASS: i32 = 1; + + #[pyattr] + pub const TYPE_MODULE: i32 = 2; + + #[pyattr] + pub const TYPE_ANNOTATION: i32 = 3; + + #[pyattr] + pub const TYPE_TYPE_VAR_BOUND: i32 = 4; + + #[pyattr] + pub const TYPE_TYPE_ALIAS: i32 = 5; + + #[pyattr] + pub const TYPE_TYPE_PARAMETERS: i32 = 6; + + #[pyattr] + pub const TYPE_TYPE_VARIABLE: i32 = 7; + #[pyfunction] fn symtable( source: PyStrRef, @@ -48,57 +146,45 @@ mod symtable { #[pyclass] impl PySymbolTable { - #[pymethod] - fn get_name(&self) -> String { + #[pygetset] + fn name(&self) -> String { self.symtable.name.clone() } - #[pymethod] - fn get_type(&self) -> String { - self.symtable.typ.to_string() + #[pygetset(name = "type")] + fn typ(&self) -> i32 { + match self.symtable.typ { + CompilerScope::Function => TYPE_FUNCTION, + CompilerScope::Class => TYPE_CLASS, + CompilerScope::Module => TYPE_MODULE, + CompilerScope::TypeParams => TYPE_TYPE_PARAMETERS, + _ => -1, // TODO: missing types from the C implementation + } } - #[pymethod] - const fn get_lineno(&self) -> u32 { + #[pygetset] + const fn lineno(&self) -> u32 { self.symtable.line_number } - #[pymethod] - const fn is_nested(&self) -> bool { - self.symtable.is_nested - } - - #[pymethod] - const fn is_optimized(&self) -> bool { - matches!( - self.symtable.typ, - CompilerScope::Function | CompilerScope::AsyncFunction - ) + #[pygetset] + fn children(&self, vm: &VirtualMachine) -> PyResult> { + let children = self + .symtable + .sub_tables + .iter() + .map(|t| to_py_symbol_table(t.clone()).into_pyobject(vm)) + .collect(); + Ok(children) } - #[pymethod] - fn lookup(&self, name: PyStrRef, vm: &VirtualMachine) -> PyResult> { - let name = name.as_str(); - if let Some(symbol) = self.symtable.symbols.get(name) { - Ok(PySymbol { - symbol: symbol.clone(), - namespaces: self - .symtable - .sub_tables - .iter() - .filter(|table| table.name == name) - .cloned() - .collect(), - is_top_scope: self.symtable.name == "top", - } - .into_ref(&vm.ctx)) - } else { - Err(vm.new_key_error(vm.ctx.new_str(format!("lookup {name} failed")).into())) - } + #[pygetset] + fn id(&self) -> usize { + self as *const Self as *const std::ffi::c_void as usize } - #[pymethod] - fn get_identifiers(&self, vm: &VirtualMachine) -> PyResult> { + #[pygetset] + fn identifiers(&self, vm: &VirtualMachine) -> PyResult> { let symbols = self .symtable .symbols @@ -108,45 +194,19 @@ mod symtable { Ok(symbols) } - #[pymethod] - fn get_symbols(&self, vm: &VirtualMachine) -> PyResult> { - let symbols = self - .symtable - .symbols - .values() - .map(|s| { - (PySymbol { - symbol: s.clone(), - namespaces: self - .symtable - .sub_tables - .iter() - .filter(|&table| table.name == s.name) - .cloned() - .collect(), - is_top_scope: self.symtable.name == "top", - }) - .into_ref(&vm.ctx) - .into() - }) - .collect(); - Ok(symbols) - } - - #[pymethod] - const fn has_children(&self) -> bool { - !self.symtable.sub_tables.is_empty() + #[pygetset] + fn symbols(&self, vm: &VirtualMachine) -> PyResult { + let dict = vm.ctx.new_dict(); + for (name, symbol) in &self.symtable.symbols { + dict.set_item(name, vm.new_pyobj(symbol.flags.bits()), vm) + .unwrap(); + } + Ok(dict) } - #[pymethod] - fn get_children(&self, vm: &VirtualMachine) -> PyResult> { - let children = self - .symtable - .sub_tables - .iter() - .map(|t| to_py_symbol_table(t.clone()).into_pyobject(vm)) - .collect(); - Ok(children) + #[pygetset] + const fn nested(&self) -> bool { + self.symtable.is_nested } }