diff --git a/bpython/patch_linecache.py b/bpython/patch_linecache.py new file mode 100644 index 000000000..9bd935b50 --- /dev/null +++ b/bpython/patch_linecache.py @@ -0,0 +1,75 @@ +import linecache + +class BPythonLinecache(dict): + """Replaces the cache dict in the standard-library linecache module, + to also remember (in an unerasable way) bpython console input.""" + + def __init__(self, *args, **kwargs): + super(BPythonLinecache, self).__init__(*args, **kwargs) + self.bpython_history = [] + + def is_bpython_filename(self, fname): + try: + return fname.startswith('' diff --git a/bpython/repl.py b/bpython/repl.py index 349cdc00b..877bafc53 100644 --- a/bpython/repl.py +++ b/bpython/repl.py @@ -49,6 +49,7 @@ from bpython.formatter import Parenthesis from bpython.history import History from bpython.paste import PasteHelper, PastePinnwand, PasteFailed +from bpython.patch_linecache import filename_for_console_input from bpython.translations import _, ngettext @@ -94,7 +95,7 @@ def __init__(self, locals=None, encoding=None): def reset_running_time(self): self.running_time = 0 - def runsource(self, source, filename='', symbol='single', + def runsource(self, source, filename=None, symbol='single', encode=True): """Execute Python code. @@ -104,6 +105,8 @@ def runsource(self, source, filename='', symbol='single', if not py3 and encode: source = u'# coding: %s\n%s' % (self.encoding, source) source = source.encode(self.encoding) + if filename is None: + filename = filename_for_console_input(source) with self.timer: return code.InteractiveInterpreter.runsource(self, source, filename, symbol) diff --git a/bpython/test/test_interpreter.py b/bpython/test/test_interpreter.py index a203accbc..43118096f 100644 --- a/bpython/test/test_interpreter.py +++ b/bpython/test/test_interpreter.py @@ -2,6 +2,7 @@ from __future__ import unicode_literals +import linecache import sys try: @@ -17,6 +18,10 @@ pypy = 'PyPy' in sys.version +def _last_console_filename(): + """Returns the last 'filename' used for console input + (as will be displayed in a traceback).""" + return '' % (len(linecache.cache.bpython_history) - 1) class TestInterpreter(unittest.TestCase): def test_syntaxerror(self): @@ -30,11 +35,11 @@ def append_to_a(message): i.runsource('1.1.1.1') if pypy: - expected = ' File ' + green('""') + ', line ' + \ + expected = ' File ' + green('"%s"' % _last_console_filename()) + ', line ' + \ bold(magenta('1')) + '\n 1.1.1.1\n ^\n' + \ bold(red('SyntaxError')) + ': ' + cyan('invalid syntax') + '\n' else: - expected = ' File ' + green('""') + ', line ' + \ + expected = ' File ' + green('"%s"' % _last_console_filename()) + ', line ' + \ bold(magenta('1')) + '\n 1.1.1.1\n ^\n' + \ bold(red('SyntaxError')) + ': ' + cyan('invalid syntax') + '\n' @@ -56,7 +61,7 @@ def f(): def g(): return f() - i.runsource('g()') + i.runsource('g()', encode=False) if pypy: global_not_found = "global name 'g' is not defined" @@ -64,8 +69,8 @@ def g(): global_not_found = "name 'g' is not defined" expected = 'Traceback (most recent call last):\n File ' + \ - green('""') + ', line ' + bold(magenta('1')) + ', in ' + \ - cyan('') + '\n' + bold(red('NameError')) + ': ' + \ + green('"%s"' % _last_console_filename()) + ', line ' + bold(magenta('1')) + ', in ' + \ + cyan('') + '\n g()\n' + bold(red('NameError')) + ': ' + \ cyan(global_not_found) + '\n' self.assertMultiLineEqual(str(plain('').join(a)), str(expected)) @@ -106,3 +111,11 @@ def test_runsource_unicode(self): i.runsource("a = u'\xfe'", encode=True) self.assertIsInstance(i.locals['a'], type(u'')) self.assertEqual(i.locals['a'], u"\xfe") + + def test_getsource_works_on_interactively_defined_functions(self): + source = 'def foo(x):\n return x + 1\n' + i = interpreter.Interp() + i.runsource(source) + import inspect + inspected_source = inspect.getsource(i.locals['foo']) + self.assertEquals(inspected_source, source)