diff --git a/Lib/shlex.py b/Lib/shlex.py index 4801a6c1d47bd9..595ea890433b4a 100644 --- a/Lib/shlex.py +++ b/Lib/shlex.py @@ -22,6 +22,10 @@ def __init__(self, instream=None, infile=None, posix=False, punctuation_chars=False): if isinstance(instream, str): instream = StringIO(instream) + elif isinstance(instream, bytes): + # convert byte instreams to string + instream = StringIO(instream.decode("ascii", "surrogateescape")) + if instream is not None: self.instream = instream self.infile = infile @@ -312,26 +316,78 @@ def split(s, comments=False, posix=True): lex.whitespace_split = True if not comments: lex.commenters = '' - return list(lex) + + if isinstance(s, bytes): + return [i.encode("ascii") for i in lex] + else: + return list(lex) def join(split_command): """Return a shell-escaped string from *split_command*.""" - return ' '.join(quote(arg) for arg in split_command) + if len(split_command) == 0: + return "" + + # Ensure all objects are the same type as the first one, + # otherwise convert and raise a warning + cleaned = [] + warned = False + + for command in split_command: + if not isinstance(command, type(split_command[0])): + # Check if user was warned not to mix types, + # warn otherwise. + if not warned: + import warnings + warnings.warn("All objects passed to join must be of the same type." + " Converting all to {}.".format(type(split_command[0]).__name__), + BytesWarning, stacklevel=2) + warned = True -_find_unsafe = re.compile(r'[^\w@%+=:,./-]', re.ASCII).search + # Convert object to opposite type + if isinstance(command, bytes): + command = command.decode("ascii", "surrogateescape") + else: + command = command.encode("ascii") + + cleaned.append(command) + + # Return a joined string or bytes object + if isinstance(cleaned[0], bytes): + return b' '.join(quote(arg) for arg in cleaned) + else: + return ' '.join(quote(arg) for arg in cleaned) + + +_find_unsafe_string = re.compile(r'[^\w@%+=:,./-]', re.ASCII).search +_find_unsafe_bytes = re.compile(rb'[^\w@%+=:,./-]').search def quote(s): """Return a shell-escaped version of the string *s*.""" - if not s: - return "''" - if _find_unsafe(s) is None: - return s - - # use single quotes, and put single quotes into double quotes - # the string $'b is then quoted as '$'"'"'b' - return "'" + s.replace("'", "'\"'\"'") + "'" + # determine if s is bytes or string object, + # and check for unsafe characters + if isinstance(s, bytes): + if not s: + return b"''" + + if _find_unsafe_bytes(s) is None: + return s + + # use single quotes, and put single quotes into double quotes + # the string $'b is then quoted as '$'"'"'b' + return b"'" + s.replace(b"'", b"'\"'\"'") + b"'" + + else: + if not s: + return "''" + + if _find_unsafe_string(s) is None: + return s + + # use single quotes, and put single quotes into double quotes + # the string $'b is then quoted as '$'"'"'b' + return "'" + s.replace("'", "'\"'\"'") + "'" def _print_tokens(lexer): diff --git a/Lib/test/test_shlex.py b/Lib/test/test_shlex.py index 3081a785204edc..5b3a5ce1945c38 100644 --- a/Lib/test/test_shlex.py +++ b/Lib/test/test_shlex.py @@ -171,6 +171,10 @@ def testSplitPosix(self): """Test data splitting with posix parser""" self.splitTest(self.posix_data, comments=True) + def testSplitBytes(self): + """Test byte objects splitting""" + self.assertEqual(shlex.split(b"split words"), [b"split", b"words"]) + def testCompat(self): """Test compatibility interface""" for i in range(len(self.data)): @@ -339,6 +343,23 @@ def testQuote(self): self.assertEqual(shlex.quote("test%s'name'" % u), "'test%s'\"'\"'name'\"'\"''" % u) + def testQuoteBytes(self): + """Test quoting of byte objects""" + # Copied from testQuote + safeunquoted = string.ascii_letters + string.digits + '@%_-+=:,./' + unicode_sample = '\xe9\xe0\xdf' # e + acute accent, a + grave, sharp s + unsafe = '"`$\\!' + unicode_sample + + self.assertEqual(shlex.quote(b''), b"''") + self.assertEqual(shlex.quote(safeunquoted.encode("ascii")), safeunquoted.encode("ascii")) + self.assertEqual(shlex.quote(b'test file name'), b"'test file name'") + for u in unsafe: + self.assertEqual(shlex.quote(('test%sname' % u).encode("utf-8")), + ("'test%sname'" % u).encode("utf-8")) + for u in unsafe: + self.assertEqual(shlex.quote(("test%s'name'" % u).encode("utf-8")), + ("'test%s'\"'\"'name'\"'\"''" % u).encode("utf-8")) + def testJoin(self): for split_command, command in [ (['a ', 'b'], "'a ' b"), @@ -350,6 +371,9 @@ def testJoin(self): joined = shlex.join(split_command) self.assertEqual(joined, command) + def testEmptyJoin(self): + self.assertEqual(shlex.join([]), "") + def testJoinRoundtrip(self): all_data = self.data + self.posix_data for command, *split_command in all_data: @@ -358,6 +382,46 @@ def testJoinRoundtrip(self): resplit = shlex.split(joined) self.assertEqual(split_command, resplit) + def testJoinBytes(self): + self.assertEqual(shlex.join([b"Join", b"me"]), b"Join me") + self.assertEqual(shlex.join([b"Just_me"]), b"Just_me") + + def testJoinBytesAndStrings(self): + """Test join can handle combinations of string and byte objects""" + # String then bytes + with self.assertWarns(BytesWarning, + msg="All objects passed to join must be of the same type." + "Converting all to str."): + self.assertEqual(shlex.join(["str_object", b"byte_object"]), "str_object byte_object") + + # Bytes then string + with self.assertWarns(BytesWarning, + msg="All objects passed to join must be of the same type." + "Converting all to bytes."): + self.assertEqual(shlex.join([b"byte_object", "str_object"]), b"byte_object str_object") + + # Random combination + with self.assertWarns(BytesWarning): + import random + + words = "This is a fully formed sentence, to test the join functionality." + + new_words = [] + for word in words.split(" "): + if bool(random.randint(0, 1)): + new_words.append(word.encode("ascii")) + else: + new_words.append(word) + + # ensure at least one bytes object to raise warning + if isinstance(new_words[2], str): + new_words[2] = new_words[2].encode("ascii") + + if isinstance(new_words[0], bytes): + self.assertEqual(shlex.join(new_words), words.encode("ascii")) + else: + self.assertEqual(shlex.join(new_words), words) + def testPunctuationCharsReadOnly(self): punctuation_chars = "/|$%^" shlex_instance = shlex.shlex(punctuation_chars=punctuation_chars) diff --git a/Misc/NEWS.d/next/Library/2020-10-12-03-46-25.bpo-25567.xgfgij.rst b/Misc/NEWS.d/next/Library/2020-10-12-03-46-25.bpo-25567.xgfgij.rst new file mode 100644 index 00000000000000..b10dc7b7ec2067 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2020-10-12-03-46-25.bpo-25567.xgfgij.rst @@ -0,0 +1 @@ +Add support for bytes objects in the shlex module.