Skip to content

Remember the source of interactively defined functions #515

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions bpython/patch_linecache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import linecache

class BPythonLinecache(dict):
"""Replaces the cache dict in the standard-library linecache module,
to also remember (in an unerasable way) bpython console input."""

def __init__(self, *args, **kwargs):
super(BPythonLinecache, self).__init__(*args, **kwargs)
self.bpython_history = []

def is_bpython_filename(self, fname):
try:
return fname.startswith('<bpython-input-')
except AttributeError:
# In case the key isn't a string
return False

def get_bpython_history(self, key):
"""Given a filename provided by remember_bpython_input,
returns the associated source string."""
try:
idx = int(key.split('-')[2][:-1])
return self.bpython_history[idx]
except (IndexError, ValueError):
raise KeyError

def remember_bpython_input(self, source):
"""Remembers a string of source code, and returns
a fake filename to use to retrieve it later."""
filename = '<bpython-input-%s>' % len(self.bpython_history)
self.bpython_history.append((len(source), None,
source.splitlines(True), filename))
return filename

def __getitem__(self, key):
if self.is_bpython_filename(key):
return self.get_bpython_history(key)
return super(BPythonLinecache, self).__getitem__(key)

def __contains__(self, key):
if self.is_bpython_filename(key):
try:
self.get_bpython_history(key)
return True
except KeyError:
return False
return super(BPythonLinecache, self).__contains__(key)

def __delitem__(self, key):
if not self.is_bpython_filename(key):
return super(BPythonLinecache, self).__delitem__(key)

def _bpython_clear_linecache():
try:
bpython_history = linecache.cache.bpython_history
except AttributeError:
bpython_history = []
linecache.cache = BPythonLinecache()
linecache.cache.bpython_history = bpython_history

# Monkey-patch the linecache module so that we're able
# to hold our command history there and have it persist
linecache.cache = BPythonLinecache(linecache.cache)
linecache.clearcache = _bpython_clear_linecache

def filename_for_console_input(code_string):
"""Remembers a string of source code, and returns
a fake filename to use to retrieve it later."""
try:
return linecache.cache.remember_bpython_input(code_string)
except AttributeError:
# If someone else has patched linecache.cache, better for code to
# simply be unavailable to inspect.getsource() than to raise
# an exception.
return '<input>'
5 changes: 4 additions & 1 deletion bpython/repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from bpython.formatter import Parenthesis
from bpython.history import History
from bpython.paste import PasteHelper, PastePinnwand, PasteFailed
from bpython.patch_linecache import filename_for_console_input
from bpython.translations import _, ngettext


Expand Down Expand Up @@ -94,7 +95,7 @@ def __init__(self, locals=None, encoding=None):
def reset_running_time(self):
self.running_time = 0

def runsource(self, source, filename='<input>', symbol='single',
def runsource(self, source, filename=None, symbol='single',
encode=True):
"""Execute Python code.

Expand All @@ -104,6 +105,8 @@ def runsource(self, source, filename='<input>', symbol='single',
if not py3 and encode:
source = u'# coding: %s\n%s' % (self.encoding, source)
source = source.encode(self.encoding)
if filename is None:
filename = filename_for_console_input(source)
with self.timer:
return code.InteractiveInterpreter.runsource(self, source,
filename, symbol)
Expand Down
23 changes: 18 additions & 5 deletions bpython/test/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import unicode_literals

import linecache
import sys

try:
Expand All @@ -17,6 +18,10 @@

pypy = 'PyPy' in sys.version

def _last_console_filename():
"""Returns the last 'filename' used for console input
(as will be displayed in a traceback)."""
return '<bpython-input-%s>' % (len(linecache.cache.bpython_history) - 1)

class TestInterpreter(unittest.TestCase):
def test_syntaxerror(self):
Expand All @@ -30,11 +35,11 @@ def append_to_a(message):
i.runsource('1.1.1.1')

if pypy:
expected = ' File ' + green('"<input>"') + ', line ' + \
expected = ' File ' + green('"%s"' % _last_console_filename()) + ', line ' + \
bold(magenta('1')) + '\n 1.1.1.1\n ^\n' + \
bold(red('SyntaxError')) + ': ' + cyan('invalid syntax') + '\n'
else:
expected = ' File ' + green('"<input>"') + ', line ' + \
expected = ' File ' + green('"%s"' % _last_console_filename()) + ', line ' + \
bold(magenta('1')) + '\n 1.1.1.1\n ^\n' + \
bold(red('SyntaxError')) + ': ' + cyan('invalid syntax') + '\n'

Expand All @@ -56,16 +61,16 @@ def f():
def g():
return f()

i.runsource('g()')
i.runsource('g()', encode=False)

if pypy:
global_not_found = "global name 'g' is not defined"
else:
global_not_found = "name 'g' is not defined"

expected = 'Traceback (most recent call last):\n File ' + \
green('"<input>"') + ', line ' + bold(magenta('1')) + ', in ' + \
cyan('<module>') + '\n' + bold(red('NameError')) + ': ' + \
green('"%s"' % _last_console_filename()) + ', line ' + bold(magenta('1')) + ', in ' + \
cyan('<module>') + '\n g()\n' + bold(red('NameError')) + ': ' + \
cyan(global_not_found) + '\n'

self.assertMultiLineEqual(str(plain('').join(a)), str(expected))
Expand Down Expand Up @@ -106,3 +111,11 @@ def test_runsource_unicode(self):
i.runsource("a = u'\xfe'", encode=True)
self.assertIsInstance(i.locals['a'], type(u''))
self.assertEqual(i.locals['a'], u"\xfe")

def test_getsource_works_on_interactively_defined_functions(self):
source = 'def foo(x):\n return x + 1\n'
i = interpreter.Interp()
i.runsource(source)
import inspect
inspected_source = inspect.getsource(i.locals['foo'])
self.assertEquals(inspected_source, source)