diff --git a/Lib/encodings/aliases.py b/Lib/encodings/aliases.py index d85afd6d5c..6a5ca046b5 100644 --- a/Lib/encodings/aliases.py +++ b/Lib/encodings/aliases.py @@ -209,6 +209,7 @@ 'ms932' : 'cp932', 'mskanji' : 'cp932', 'ms_kanji' : 'cp932', + 'windows_31j' : 'cp932', # cp949 codec '949' : 'cp949', diff --git a/Lib/encodings/idna.py b/Lib/encodings/idna.py index 5396047a7f..0c90b4c9fe 100644 --- a/Lib/encodings/idna.py +++ b/Lib/encodings/idna.py @@ -11,7 +11,7 @@ sace_prefix = "xn--" # This assumes query strings, so AllowUnassigned is true -def nameprep(label): +def nameprep(label): # type: (str) -> str # Map newlabel = [] for c in label: @@ -25,7 +25,7 @@ def nameprep(label): label = unicodedata.normalize("NFKC", label) # Prohibit - for c in label: + for i, c in enumerate(label): if stringprep.in_table_c12(c) or \ stringprep.in_table_c22(c) or \ stringprep.in_table_c3(c) or \ @@ -35,7 +35,7 @@ def nameprep(label): stringprep.in_table_c7(c) or \ stringprep.in_table_c8(c) or \ stringprep.in_table_c9(c): - raise UnicodeError("Invalid character %r" % c) + raise UnicodeEncodeError("idna", label, i, i+1, f"Invalid character {c!r}") # Check bidi RandAL = [stringprep.in_table_d1(x) for x in label] @@ -46,29 +46,38 @@ def nameprep(label): # This is table C.8, which was already checked # 2) If a string contains any RandALCat character, the string # MUST NOT contain any LCat character. - if any(stringprep.in_table_d2(x) for x in label): - raise UnicodeError("Violation of BIDI requirement 2") + for i, x in enumerate(label): + if stringprep.in_table_d2(x): + raise UnicodeEncodeError("idna", label, i, i+1, + "Violation of BIDI requirement 2") # 3) If a string contains any RandALCat character, a # RandALCat character MUST be the first character of the # string, and a RandALCat character MUST be the last # character of the string. - if not RandAL[0] or not RandAL[-1]: - raise UnicodeError("Violation of BIDI requirement 3") + if not RandAL[0]: + raise UnicodeEncodeError("idna", label, 0, 1, + "Violation of BIDI requirement 3") + if not RandAL[-1]: + raise UnicodeEncodeError("idna", label, len(label)-1, len(label), + "Violation of BIDI requirement 3") return label -def ToASCII(label): +def ToASCII(label): # type: (str) -> bytes try: # Step 1: try ASCII - label = label.encode("ascii") - except UnicodeError: + label_ascii = label.encode("ascii") + except UnicodeEncodeError: pass else: # Skip to step 3: UseSTD3ASCIIRules is false, so # Skip to step 8. - if 0 < len(label) < 64: - return label - raise UnicodeError("label empty or too long") + if 0 < len(label_ascii) < 64: + return label_ascii + if len(label) == 0: + raise UnicodeEncodeError("idna", label, 0, 1, "label empty") + else: + raise UnicodeEncodeError("idna", label, 0, len(label), "label too long") # Step 2: nameprep label = nameprep(label) @@ -76,29 +85,34 @@ def ToASCII(label): # Step 3: UseSTD3ASCIIRules is false # Step 4: try ASCII try: - label = label.encode("ascii") - except UnicodeError: + label_ascii = label.encode("ascii") + except UnicodeEncodeError: pass else: # Skip to step 8. if 0 < len(label) < 64: - return label - raise UnicodeError("label empty or too long") + return label_ascii + if len(label) == 0: + raise UnicodeEncodeError("idna", label, 0, 1, "label empty") + else: + raise UnicodeEncodeError("idna", label, 0, len(label), "label too long") # Step 5: Check ACE prefix - if label.startswith(sace_prefix): - raise UnicodeError("Label starts with ACE prefix") + if label.lower().startswith(sace_prefix): + raise UnicodeEncodeError( + "idna", label, 0, len(sace_prefix), "Label starts with ACE prefix") # Step 6: Encode with PUNYCODE - label = label.encode("punycode") + label_ascii = label.encode("punycode") # Step 7: Prepend ACE prefix - label = ace_prefix + label + label_ascii = ace_prefix + label_ascii # Step 8: Check size - if 0 < len(label) < 64: - return label - raise UnicodeError("label empty or too long") + # do not check for empty as we prepend ace_prefix. + if len(label_ascii) < 64: + return label_ascii + raise UnicodeEncodeError("idna", label, 0, len(label), "label too long") def ToUnicode(label): if len(label) > 1024: @@ -110,7 +124,9 @@ def ToUnicode(label): # per https://www.rfc-editor.org/rfc/rfc3454#section-3.1 while still # preventing us from wasting time decoding a big thing that'll just # hit the actual <= 63 length limit in Step 6. - raise UnicodeError("label way too long") + if isinstance(label, str): + label = label.encode("utf-8", errors="backslashreplace") + raise UnicodeDecodeError("idna", label, 0, len(label), "label way too long") # Step 1: Check for ASCII if isinstance(label, bytes): pure_ascii = True @@ -118,25 +134,32 @@ def ToUnicode(label): try: label = label.encode("ascii") pure_ascii = True - except UnicodeError: + except UnicodeEncodeError: pure_ascii = False if not pure_ascii: + assert isinstance(label, str) # Step 2: Perform nameprep label = nameprep(label) # It doesn't say this, but apparently, it should be ASCII now try: label = label.encode("ascii") - except UnicodeError: - raise UnicodeError("Invalid character in IDN label") + except UnicodeEncodeError as exc: + raise UnicodeEncodeError("idna", label, exc.start, exc.end, + "Invalid character in IDN label") # Step 3: Check for ACE prefix - if not label.startswith(ace_prefix): + assert isinstance(label, bytes) + if not label.lower().startswith(ace_prefix): return str(label, "ascii") # Step 4: Remove ACE prefix label1 = label[len(ace_prefix):] # Step 5: Decode using PUNYCODE - result = label1.decode("punycode") + try: + result = label1.decode("punycode") + except UnicodeDecodeError as exc: + offset = len(ace_prefix) + raise UnicodeDecodeError("idna", label, offset+exc.start, offset+exc.end, exc.reason) # Step 6: Apply ToASCII label2 = ToASCII(result) @@ -144,7 +167,8 @@ def ToUnicode(label): # Step 7: Compare the result of step 6 with the one of step 3 # label2 will already be in lower case. if str(label, "ascii").lower() != str(label2, "ascii"): - raise UnicodeError("IDNA does not round-trip", label, label2) + raise UnicodeDecodeError("idna", label, 0, len(label), + f"IDNA does not round-trip, '{label!r}' != '{label2!r}'") # Step 8: return the result of step 5 return result @@ -156,7 +180,7 @@ def encode(self, input, errors='strict'): if errors != 'strict': # IDNA is quite clear that implementations must be strict - raise UnicodeError("unsupported error handling "+errors) + raise UnicodeError(f"Unsupported error handling: {errors}") if not input: return b'', 0 @@ -168,11 +192,16 @@ def encode(self, input, errors='strict'): else: # ASCII name: fast path labels = result.split(b'.') - for label in labels[:-1]: - if not (0 < len(label) < 64): - raise UnicodeError("label empty or too long") - if len(labels[-1]) >= 64: - raise UnicodeError("label too long") + for i, label in enumerate(labels[:-1]): + if len(label) == 0: + offset = sum(len(l) for l in labels[:i]) + i + raise UnicodeEncodeError("idna", input, offset, offset+1, + "label empty") + for i, label in enumerate(labels): + if len(label) >= 64: + offset = sum(len(l) for l in labels[:i]) + i + raise UnicodeEncodeError("idna", input, offset, offset+len(label), + "label too long") return result, len(input) result = bytearray() @@ -182,17 +211,27 @@ def encode(self, input, errors='strict'): del labels[-1] else: trailing_dot = b'' - for label in labels: + for i, label in enumerate(labels): if result: # Join with U+002E result.extend(b'.') - result.extend(ToASCII(label)) + try: + result.extend(ToASCII(label)) + except (UnicodeEncodeError, UnicodeDecodeError) as exc: + offset = sum(len(l) for l in labels[:i]) + i + raise UnicodeEncodeError( + "idna", + input, + offset + exc.start, + offset + exc.end, + exc.reason, + ) return bytes(result+trailing_dot), len(input) def decode(self, input, errors='strict'): if errors != 'strict': - raise UnicodeError("Unsupported error handling "+errors) + raise UnicodeError(f"Unsupported error handling: {errors}") if not input: return "", 0 @@ -202,7 +241,7 @@ def decode(self, input, errors='strict'): # XXX obviously wrong, see #3232 input = bytes(input) - if ace_prefix not in input: + if ace_prefix not in input.lower(): # Fast path try: return input.decode('ascii'), len(input) @@ -218,8 +257,15 @@ def decode(self, input, errors='strict'): trailing_dot = '' result = [] - for label in labels: - result.append(ToUnicode(label)) + for i, label in enumerate(labels): + try: + u_label = ToUnicode(label) + except (UnicodeEncodeError, UnicodeDecodeError) as exc: + offset = sum(len(x) for x in labels[:i]) + len(labels[:i]) + raise UnicodeDecodeError( + "idna", input, offset+exc.start, offset+exc.end, exc.reason) + else: + result.append(u_label) return ".".join(result)+trailing_dot, len(input) @@ -227,7 +273,7 @@ class IncrementalEncoder(codecs.BufferedIncrementalEncoder): def _buffer_encode(self, input, errors, final): if errors != 'strict': # IDNA is quite clear that implementations must be strict - raise UnicodeError("unsupported error handling "+errors) + raise UnicodeError(f"Unsupported error handling: {errors}") if not input: return (b'', 0) @@ -251,7 +297,16 @@ def _buffer_encode(self, input, errors, final): # Join with U+002E result.extend(b'.') size += 1 - result.extend(ToASCII(label)) + try: + result.extend(ToASCII(label)) + except (UnicodeEncodeError, UnicodeDecodeError) as exc: + raise UnicodeEncodeError( + "idna", + input, + size + exc.start, + size + exc.end, + exc.reason, + ) size += len(label) result += trailing_dot @@ -261,7 +316,7 @@ def _buffer_encode(self, input, errors, final): class IncrementalDecoder(codecs.BufferedIncrementalDecoder): def _buffer_decode(self, input, errors, final): if errors != 'strict': - raise UnicodeError("Unsupported error handling "+errors) + raise UnicodeError(f"Unsupported error handling: {errors}") if not input: return ("", 0) @@ -271,7 +326,11 @@ def _buffer_decode(self, input, errors, final): labels = dots.split(input) else: # Must be ASCII string - input = str(input, "ascii") + try: + input = str(input, "ascii") + except (UnicodeEncodeError, UnicodeDecodeError) as exc: + raise UnicodeDecodeError("idna", input, + exc.start, exc.end, exc.reason) labels = input.split(".") trailing_dot = '' @@ -288,7 +347,18 @@ def _buffer_decode(self, input, errors, final): result = [] size = 0 for label in labels: - result.append(ToUnicode(label)) + try: + u_label = ToUnicode(label) + except (UnicodeEncodeError, UnicodeDecodeError) as exc: + raise UnicodeDecodeError( + "idna", + input.encode("ascii", errors="backslashreplace"), + size + exc.start, + size + exc.end, + exc.reason, + ) + else: + result.append(u_label) if size: size += 1 size += len(label) diff --git a/Lib/encodings/palmos.py b/Lib/encodings/palmos.py index c506d65452..df164ca5b9 100644 --- a/Lib/encodings/palmos.py +++ b/Lib/encodings/palmos.py @@ -201,7 +201,7 @@ def getregentry(): '\u02dc' # 0x98 -> SMALL TILDE '\u2122' # 0x99 -> TRADE MARK SIGN '\u0161' # 0x9A -> LATIN SMALL LETTER S WITH CARON - '\x9b' # 0x9B -> + '\u203a' # 0x9B -> SINGLE RIGHT-POINTING ANGLE QUOTATION MARK '\u0153' # 0x9C -> LATIN SMALL LIGATURE OE '\x9d' # 0x9D -> '\x9e' # 0x9E -> diff --git a/Lib/encodings/punycode.py b/Lib/encodings/punycode.py index 1c57264470..4622fc8c92 100644 --- a/Lib/encodings/punycode.py +++ b/Lib/encodings/punycode.py @@ -1,4 +1,4 @@ -""" Codec for the Punicode encoding, as specified in RFC 3492 +""" Codec for the Punycode encoding, as specified in RFC 3492 Written by Martin v. Löwis. """ @@ -131,10 +131,11 @@ def decode_generalized_number(extended, extpos, bias, errors): j = 0 while 1: try: - char = ord(extended[extpos]) + char = extended[extpos] except IndexError: if errors == "strict": - raise UnicodeError("incomplete punicode string") + raise UnicodeDecodeError("punycode", extended, extpos, extpos+1, + "incomplete punycode string") return extpos + 1, None extpos += 1 if 0x41 <= char <= 0x5A: # A-Z @@ -142,8 +143,8 @@ def decode_generalized_number(extended, extpos, bias, errors): elif 0x30 <= char <= 0x39: digit = char - 22 # 0x30-26 elif errors == "strict": - raise UnicodeError("Invalid extended code point '%s'" - % extended[extpos-1]) + raise UnicodeDecodeError("punycode", extended, extpos-1, extpos, + f"Invalid extended code point '{extended[extpos-1]}'") else: return extpos, None t = T(j, bias) @@ -155,11 +156,14 @@ def decode_generalized_number(extended, extpos, bias, errors): def insertion_sort(base, extended, errors): - """3.2 Insertion unsort coding""" + """3.2 Insertion sort coding""" + # This function raises UnicodeDecodeError with position in the extended. + # Caller should add the offset. char = 0x80 pos = -1 bias = 72 extpos = 0 + while extpos < len(extended): newpos, delta = decode_generalized_number(extended, extpos, bias, errors) @@ -171,7 +175,9 @@ def insertion_sort(base, extended, errors): char += pos // (len(base) + 1) if char > 0x10FFFF: if errors == "strict": - raise UnicodeError("Invalid character U+%x" % char) + raise UnicodeDecodeError( + "punycode", extended, pos-1, pos, + f"Invalid character U+{char:x}") char = ord('?') pos = pos % (len(base) + 1) base = base[:pos] + chr(char) + base[pos:] @@ -187,11 +193,21 @@ def punycode_decode(text, errors): pos = text.rfind(b"-") if pos == -1: base = "" - extended = str(text, "ascii").upper() + extended = text.upper() else: - base = str(text[:pos], "ascii", errors) - extended = str(text[pos+1:], "ascii").upper() - return insertion_sort(base, extended, errors) + try: + base = str(text[:pos], "ascii", errors) + except UnicodeDecodeError as exc: + raise UnicodeDecodeError("ascii", text, exc.start, exc.end, + exc.reason) from None + extended = text[pos+1:].upper() + try: + return insertion_sort(base, extended, errors) + except UnicodeDecodeError as exc: + offset = pos + 1 + raise UnicodeDecodeError("punycode", text, + offset+exc.start, offset+exc.end, + exc.reason) from None ### Codec APIs @@ -203,7 +219,7 @@ def encode(self, input, errors='strict'): def decode(self, input, errors='strict'): if errors not in ('strict', 'replace', 'ignore'): - raise UnicodeError("Unsupported error handling "+errors) + raise UnicodeError(f"Unsupported error handling: {errors}") res = punycode_decode(input, errors) return res, len(input) @@ -214,7 +230,7 @@ def encode(self, input, final=False): class IncrementalDecoder(codecs.IncrementalDecoder): def decode(self, input, final=False): if self.errors not in ('strict', 'replace', 'ignore'): - raise UnicodeError("Unsupported error handling "+self.errors) + raise UnicodeError(f"Unsupported error handling: {self.errors}") return punycode_decode(input, self.errors) class StreamWriter(Codec,codecs.StreamWriter): diff --git a/Lib/encodings/undefined.py b/Lib/encodings/undefined.py index 4690288355..082771e1c8 100644 --- a/Lib/encodings/undefined.py +++ b/Lib/encodings/undefined.py @@ -1,6 +1,6 @@ """ Python 'undefined' Codec - This codec will always raise a ValueError exception when being + This codec will always raise a UnicodeError exception when being used. It is intended for use by the site.py file to switch off automatic string to Unicode coercion. diff --git a/Lib/encodings/utf_16.py b/Lib/encodings/utf_16.py index c61248242b..d3b9980026 100644 --- a/Lib/encodings/utf_16.py +++ b/Lib/encodings/utf_16.py @@ -64,7 +64,7 @@ def _buffer_decode(self, input, errors, final): elif byteorder == 1: self.decoder = codecs.utf_16_be_decode elif consumed >= 2: - raise UnicodeError("UTF-16 stream does not start with BOM") + raise UnicodeDecodeError("utf-16", input, 0, 2, "Stream does not start with BOM") return (output, consumed) return self.decoder(input, self.errors, final) @@ -138,7 +138,7 @@ def decode(self, input, errors='strict'): elif byteorder == 1: self.decode = codecs.utf_16_be_decode elif consumed>=2: - raise UnicodeError("UTF-16 stream does not start with BOM") + raise UnicodeDecodeError("utf-16", input, 0, 2, "Stream does not start with BOM") return (object, consumed) ### encodings module API diff --git a/Lib/encodings/utf_32.py b/Lib/encodings/utf_32.py index cdf84d1412..1924bedbb7 100644 --- a/Lib/encodings/utf_32.py +++ b/Lib/encodings/utf_32.py @@ -59,7 +59,7 @@ def _buffer_decode(self, input, errors, final): elif byteorder == 1: self.decoder = codecs.utf_32_be_decode elif consumed >= 4: - raise UnicodeError("UTF-32 stream does not start with BOM") + raise UnicodeDecodeError("utf-32", input, 0, 4, "Stream does not start with BOM") return (output, consumed) return self.decoder(input, self.errors, final) @@ -132,8 +132,8 @@ def decode(self, input, errors='strict'): self.decode = codecs.utf_32_le_decode elif byteorder == 1: self.decode = codecs.utf_32_be_decode - elif consumed>=4: - raise UnicodeError("UTF-32 stream does not start with BOM") + elif consumed >= 4: + raise UnicodeDecodeError("utf-32", input, 0, 4, "Stream does not start with BOM") return (object, consumed) ### encodings module API diff --git a/Lib/fnmatch.py b/Lib/fnmatch.py index fee59bf73f..73acb1fe8d 100644 --- a/Lib/fnmatch.py +++ b/Lib/fnmatch.py @@ -16,12 +16,6 @@ __all__ = ["filter", "fnmatch", "fnmatchcase", "translate"] -# Build a thread-safe incrementing counter to help create unique regexp group -# names across calls. -from itertools import count -_nextgroupnum = count().__next__ -del count - def fnmatch(name, pat): """Test whether FILENAME matches PATTERN. @@ -41,7 +35,7 @@ def fnmatch(name, pat): pat = os.path.normcase(pat) return fnmatchcase(name, pat) -@functools.lru_cache(maxsize=256, typed=True) +@functools.lru_cache(maxsize=32768, typed=True) def _compile_pattern(pat): if isinstance(pat, bytes): pat_str = str(pat, 'ISO-8859-1') @@ -84,6 +78,11 @@ def translate(pat): """ STAR = object() + parts = _translate(pat, STAR, '.') + return _join_translated_parts(parts, STAR) + + +def _translate(pat, STAR, QUESTION_MARK): res = [] add = res.append i, n = 0, len(pat) @@ -95,7 +94,7 @@ def translate(pat): if (not res) or res[-1] is not STAR: add(STAR) elif c == '?': - add('.') + add(QUESTION_MARK) elif c == '[': j = i if j < n and pat[j] == '!': @@ -152,9 +151,11 @@ def translate(pat): else: add(re.escape(c)) assert i == n + return res + +def _join_translated_parts(inp, STAR): # Deal with STARs. - inp = res res = [] add = res.append i, n = 0, len(inp) @@ -165,17 +166,10 @@ def translate(pat): # Now deal with STAR fixed STAR fixed ... # For an interior `STAR fixed` pairing, we want to do a minimal # .*? match followed by `fixed`, with no possibility of backtracking. - # We can't spell that directly, but can trick it into working by matching - # .*?fixed - # in a lookahead assertion, save the matched part in a group, then - # consume that group via a backreference. If the overall match fails, - # the lookahead assertion won't try alternatives. So the translation is: - # (?=(?P.*?fixed))(?P=name) - # Group names are created as needed: g0, g1, g2, ... - # The numbers are obtained from _nextgroupnum() to ensure they're unique - # across calls and across threads. This is because people rely on the - # undocumented ability to join multiple translate() results together via - # "|" to build large regexps matching "one of many" shell patterns. + # Atomic groups ("(?>...)") allow us to spell that directly. + # Note: people rely on the undocumented ability to join multiple + # translate() results together via "|" to build large regexps matching + # "one of many" shell patterns. while i < n: assert inp[i] is STAR i += 1 @@ -192,8 +186,7 @@ def translate(pat): add(".*") add(fixed) else: - groupnum = _nextgroupnum() - add(f"(?=(?P.*?{fixed}))(?P=g{groupnum})") + add(f"(?>.*?{fixed})") assert i == n res = "".join(res) return fr'(?s:{res})\Z' diff --git a/Lib/glob.py b/Lib/glob.py index 50beef37f4..c506e0e215 100644 --- a/Lib/glob.py +++ b/Lib/glob.py @@ -4,11 +4,14 @@ import os import re import fnmatch +import functools import itertools +import operator import stat import sys -__all__ = ["glob", "iglob", "escape"] + +__all__ = ["glob", "iglob", "escape", "translate"] def glob(pathname, *, root_dir=None, dir_fd=None, recursive=False, include_hidden=False): @@ -104,8 +107,8 @@ def _iglob(pathname, root_dir, dir_fd, recursive, dironly, def _glob1(dirname, pattern, dir_fd, dironly, include_hidden=False): names = _listdir(dirname, dir_fd, dironly) - if include_hidden or not _ishidden(pattern): - names = (x for x in names if include_hidden or not _ishidden(x)) + if not (include_hidden or _ishidden(pattern)): + names = (x for x in names if not _ishidden(x)) return fnmatch.filter(names, pattern) def _glob0(dirname, basename, dir_fd, dironly, include_hidden=False): @@ -119,12 +122,19 @@ def _glob0(dirname, basename, dir_fd, dironly, include_hidden=False): return [basename] return [] -# Following functions are not public but can be used by third-party code. +_deprecated_function_message = ( + "{name} is deprecated and will be removed in Python {remove}. Use " + "glob.glob and pass a directory to its root_dir argument instead." +) def glob0(dirname, pattern): + import warnings + warnings._deprecated("glob.glob0", _deprecated_function_message, remove=(3, 15)) return _glob0(dirname, pattern, None, False) def glob1(dirname, pattern): + import warnings + warnings._deprecated("glob.glob1", _deprecated_function_message, remove=(3, 15)) return _glob1(dirname, pattern, None, False) # This helper function recursively yields relative pathnames inside a literal @@ -249,4 +259,287 @@ def escape(pathname): return drive + pathname +_special_parts = ('', '.', '..') _dir_open_flags = os.O_RDONLY | getattr(os, 'O_DIRECTORY', 0) +_no_recurse_symlinks = object() + + +def translate(pat, *, recursive=False, include_hidden=False, seps=None): + """Translate a pathname with shell wildcards to a regular expression. + + If `recursive` is true, the pattern segment '**' will match any number of + path segments. + + If `include_hidden` is true, wildcards can match path segments beginning + with a dot ('.'). + + If a sequence of separator characters is given to `seps`, they will be + used to split the pattern into segments and match path separators. If not + given, os.path.sep and os.path.altsep (where available) are used. + """ + if not seps: + if os.path.altsep: + seps = (os.path.sep, os.path.altsep) + else: + seps = os.path.sep + escaped_seps = ''.join(map(re.escape, seps)) + any_sep = f'[{escaped_seps}]' if len(seps) > 1 else escaped_seps + not_sep = f'[^{escaped_seps}]' + if include_hidden: + one_last_segment = f'{not_sep}+' + one_segment = f'{one_last_segment}{any_sep}' + any_segments = f'(?:.+{any_sep})?' + any_last_segments = '.*' + else: + one_last_segment = f'[^{escaped_seps}.]{not_sep}*' + one_segment = f'{one_last_segment}{any_sep}' + any_segments = f'(?:{one_segment})*' + any_last_segments = f'{any_segments}(?:{one_last_segment})?' + + results = [] + parts = re.split(any_sep, pat) + last_part_idx = len(parts) - 1 + for idx, part in enumerate(parts): + if part == '*': + results.append(one_segment if idx < last_part_idx else one_last_segment) + elif recursive and part == '**': + if idx < last_part_idx: + if parts[idx + 1] != '**': + results.append(any_segments) + else: + results.append(any_last_segments) + else: + if part: + if not include_hidden and part[0] in '*?': + results.append(r'(?!\.)') + results.extend(fnmatch._translate(part, f'{not_sep}*', not_sep)) + if idx < last_part_idx: + results.append(any_sep) + res = ''.join(results) + return fr'(?s:{res})\Z' + + +@functools.lru_cache(maxsize=512) +def _compile_pattern(pat, sep, case_sensitive, recursive=True): + """Compile given glob pattern to a re.Pattern object (observing case + sensitivity).""" + flags = re.NOFLAG if case_sensitive else re.IGNORECASE + regex = translate(pat, recursive=recursive, include_hidden=True, seps=sep) + return re.compile(regex, flags=flags).match + + +class _Globber: + """Class providing shell-style pattern matching and globbing. + """ + + def __init__(self, sep, case_sensitive, case_pedantic=False, recursive=False): + self.sep = sep + self.case_sensitive = case_sensitive + self.case_pedantic = case_pedantic + self.recursive = recursive + + # Low-level methods + + lstat = operator.methodcaller('lstat') + add_slash = operator.methodcaller('joinpath', '') + + @staticmethod + def scandir(path): + """Emulates os.scandir(), which returns an object that can be used as + a context manager. This method is called by walk() and glob(). + """ + return contextlib.nullcontext(path.iterdir()) + + @staticmethod + def concat_path(path, text): + """Appends text to the given path. + """ + return path.with_segments(path._raw_path + text) + + @staticmethod + def parse_entry(entry): + """Returns the path of an entry yielded from scandir(). + """ + return entry + + # High-level methods + + def compile(self, pat): + return _compile_pattern(pat, self.sep, self.case_sensitive, self.recursive) + + def selector(self, parts): + """Returns a function that selects from a given path, walking and + filtering according to the glob-style pattern parts in *parts*. + """ + if not parts: + return self.select_exists + part = parts.pop() + if self.recursive and part == '**': + selector = self.recursive_selector + elif part in _special_parts: + selector = self.special_selector + elif not self.case_pedantic and magic_check.search(part) is None: + selector = self.literal_selector + else: + selector = self.wildcard_selector + return selector(part, parts) + + def special_selector(self, part, parts): + """Returns a function that selects special children of the given path. + """ + select_next = self.selector(parts) + + def select_special(path, exists=False): + path = self.concat_path(self.add_slash(path), part) + return select_next(path, exists) + return select_special + + def literal_selector(self, part, parts): + """Returns a function that selects a literal descendant of a path. + """ + + # Optimization: consume and join any subsequent literal parts here, + # rather than leaving them for the next selector. This reduces the + # number of string concatenation operations and calls to add_slash(). + while parts and magic_check.search(parts[-1]) is None: + part += self.sep + parts.pop() + + select_next = self.selector(parts) + + def select_literal(path, exists=False): + path = self.concat_path(self.add_slash(path), part) + return select_next(path, exists=False) + return select_literal + + def wildcard_selector(self, part, parts): + """Returns a function that selects direct children of a given path, + filtering by pattern. + """ + + match = None if part == '*' else self.compile(part) + dir_only = bool(parts) + if dir_only: + select_next = self.selector(parts) + + def select_wildcard(path, exists=False): + try: + # We must close the scandir() object before proceeding to + # avoid exhausting file descriptors when globbing deep trees. + with self.scandir(path) as scandir_it: + entries = list(scandir_it) + except OSError: + pass + else: + for entry in entries: + if match is None or match(entry.name): + if dir_only: + try: + if not entry.is_dir(): + continue + except OSError: + continue + entry_path = self.parse_entry(entry) + if dir_only: + yield from select_next(entry_path, exists=True) + else: + yield entry_path + return select_wildcard + + def recursive_selector(self, part, parts): + """Returns a function that selects a given path and all its children, + recursively, filtering by pattern. + """ + # Optimization: consume following '**' parts, which have no effect. + while parts and parts[-1] == '**': + parts.pop() + + # Optimization: consume and join any following non-special parts here, + # rather than leaving them for the next selector. They're used to + # build a regular expression, which we use to filter the results of + # the recursive walk. As a result, non-special pattern segments + # following a '**' wildcard don't require additional filesystem access + # to expand. + follow_symlinks = self.recursive is not _no_recurse_symlinks + if follow_symlinks: + while parts and parts[-1] not in _special_parts: + part += self.sep + parts.pop() + + match = None if part == '**' else self.compile(part) + dir_only = bool(parts) + select_next = self.selector(parts) + + def select_recursive(path, exists=False): + path = self.add_slash(path) + match_pos = len(str(path)) + if match is None or match(str(path), match_pos): + yield from select_next(path, exists) + stack = [path] + while stack: + yield from select_recursive_step(stack, match_pos) + + def select_recursive_step(stack, match_pos): + path = stack.pop() + try: + # We must close the scandir() object before proceeding to + # avoid exhausting file descriptors when globbing deep trees. + with self.scandir(path) as scandir_it: + entries = list(scandir_it) + except OSError: + pass + else: + for entry in entries: + is_dir = False + try: + if entry.is_dir(follow_symlinks=follow_symlinks): + is_dir = True + except OSError: + pass + + if is_dir or not dir_only: + entry_path = self.parse_entry(entry) + if match is None or match(str(entry_path), match_pos): + if dir_only: + yield from select_next(entry_path, exists=True) + else: + # Optimization: directly yield the path if this is + # last pattern part. + yield entry_path + if is_dir: + stack.append(entry_path) + + return select_recursive + + def select_exists(self, path, exists=False): + """Yields the given path, if it exists. + """ + if exists: + # Optimization: this path is already known to exist, e.g. because + # it was returned from os.scandir(), so we skip calling lstat(). + yield path + else: + try: + self.lstat(path) + yield path + except OSError: + pass + + +class _StringGlobber(_Globber): + lstat = staticmethod(os.lstat) + scandir = staticmethod(os.scandir) + parse_entry = operator.attrgetter('path') + concat_path = operator.add + + if os.name == 'nt': + @staticmethod + def add_slash(pathname): + tail = os.path.splitroot(pathname)[2] + if not tail or tail[-1] in '\\/': + return pathname + return f'{pathname}\\' + else: + @staticmethod + def add_slash(pathname): + if not pathname or pathname[-1] == '/': + return pathname + return f'{pathname}/' diff --git a/Lib/io.py b/Lib/io.py index c2812876d3..f0e2fa15d5 100644 --- a/Lib/io.py +++ b/Lib/io.py @@ -46,23 +46,17 @@ "BufferedReader", "BufferedWriter", "BufferedRWPair", "BufferedRandom", "TextIOBase", "TextIOWrapper", "UnsupportedOperation", "SEEK_SET", "SEEK_CUR", "SEEK_END", - "DEFAULT_BUFFER_SIZE", "text_encoding", - "IncrementalNewlineDecoder" - ] + "DEFAULT_BUFFER_SIZE", "text_encoding", "IncrementalNewlineDecoder"] import _io import abc from _io import (DEFAULT_BUFFER_SIZE, BlockingIOError, UnsupportedOperation, - open, open_code, BytesIO, StringIO, BufferedReader, + open, open_code, FileIO, BytesIO, StringIO, BufferedReader, BufferedWriter, BufferedRWPair, BufferedRandom, IncrementalNewlineDecoder, text_encoding, TextIOWrapper) -try: - from _io import FileIO -except ImportError: - pass # Pretend this exception was created here. UnsupportedOperation.__module__ = "io" @@ -87,10 +81,7 @@ class BufferedIOBase(_io._BufferedIOBase, IOBase): class TextIOBase(_io._TextIOBase, IOBase): __doc__ = _io._TextIOBase.__doc__ -try: - RawIOBase.register(FileIO) -except NameError: - pass +RawIOBase.register(FileIO) for klass in (BytesIO, BufferedReader, BufferedWriter, BufferedRandom, BufferedRWPair): diff --git a/Lib/test/test_code.py b/Lib/test/test_code.py index 6b0dc09e28..b7e5784b48 100644 --- a/Lib/test/test_code.py +++ b/Lib/test/test_code.py @@ -347,8 +347,6 @@ def func(arg): newcode = code.replace(co_name="func") # Should not raise SystemError self.assertEqual(code, newcode) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_empty_linetable(self): def func(): pass @@ -468,8 +466,6 @@ def f(): # co_positions behavior when info is missing. - # TODO: RUSTPYTHON - @unittest.expectedFailure # @requires_debug_ranges() def test_co_positions_empty_linetable(self): def func(): @@ -480,8 +476,6 @@ def func(): self.assertIsNone(line) self.assertEqual(end_line, new_code.co_firstlineno + 1) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_code_equality(self): def f(): try: @@ -522,8 +516,6 @@ def test_code_hash_uses_order(self): self.assertNotEqual(c, swapped) self.assertNotEqual(hash(c), hash(swapped)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_code_hash_uses_bytecode(self): c = (lambda x, y: x + y).__code__ d = (lambda x, y: x * y).__code__ @@ -735,8 +727,6 @@ def check_positions(self, func): self.assertEqual(l1, l2) self.assertEqual(len(pos1), len(pos2)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_positions(self): self.check_positions(parse_location_table) self.check_positions(misshappen) @@ -751,8 +741,6 @@ def check_lines(self, func): self.assertEqual(l1, l2) self.assertEqual(len(lines1), len(lines2)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_lines(self): self.check_lines(parse_location_table) self.check_lines(misshappen) diff --git a/Lib/test/test_codecs.py b/Lib/test/test_codecs.py index f986d85c6d..0d3d8c9e2d 100644 --- a/Lib/test/test_codecs.py +++ b/Lib/test/test_codecs.py @@ -762,7 +762,6 @@ def test_only_one_bom(self): f = reader(s) self.assertEqual(f.read(), "spamspam") - @unittest.expectedFailure # TODO: RUSTPYTHON;; UTF-16 stream does not start with BOM def test_badbom(self): s = io.BytesIO(b"\xff\xff") f = codecs.getreader(self.encoding)(s) @@ -1509,7 +1508,6 @@ def test_decode(self): puny = puny.decode("ascii").encode("ascii") self.assertEqual(uni, puny.decode("punycode")) - @unittest.expectedFailure # TODO: RUSTPYTHON; b'Pro\xffprostnemluvesky' != b'Pro\xffprostnemluvesky-uyb24dma41a' def test_decode_invalid(self): testcases = [ (b"xn--w&", "strict", UnicodeDecodeError("punycode", b"", 5, 6, "")), @@ -1694,7 +1692,6 @@ def test_decode_invalid(self): class NameprepTest(unittest.TestCase): - @unittest.expectedFailure # TODO: RUSTPYTHON; UnicodeError: Invalid character '\u1680' def test_nameprep(self): from encodings.idna import nameprep for pos, (orig, prepped) in enumerate(nameprep_tests): @@ -1732,7 +1729,6 @@ class IDNACodecTest(unittest.TestCase): ("あさ.\u034f", UnicodeEncodeError("idna", "あさ.\u034f", 3, 4, "")), ] - @unittest.expectedFailure # TODO: RUSTPYTHON; 'XN--pythn-mua.org.' != 'pythön.org.' def test_builtin_decode(self): self.assertEqual(str(b"python.org", "idna"), "python.org") self.assertEqual(str(b"python.org.", "idna"), "python.org.") @@ -1746,7 +1742,6 @@ def test_builtin_decode(self): self.assertEqual(str(b"bugs.XN--pythn-mua.org.", "idna"), "bugs.pyth\xf6n.org.") - @unittest.expectedFailure # TODO: RUSTPYTHON; 'ascii' != 'idna' def test_builtin_decode_invalid(self): for case, expected in self.invalid_decode_testcases: with self.subTest(case=case, expected=expected): @@ -1764,7 +1759,6 @@ def test_builtin_encode(self): self.assertEqual("pyth\xf6n.org".encode("idna"), b"xn--pythn-mua.org") self.assertEqual("pyth\xf6n.org.".encode("idna"), b"xn--pythn-mua.org.") - @unittest.expectedFailure # TODO: RUSTPYTHON; UnicodeError: label empty or too long def test_builtin_encode_invalid(self): for case, expected in self.invalid_encode_testcases: with self.subTest(case=case, expected=expected): @@ -1776,7 +1770,6 @@ def test_builtin_encode_invalid(self): self.assertEqual(exc.start, expected.start) self.assertEqual(exc.end, expected.end) - @unittest.expectedFailure # TODO: RUSTPYTHON; UnicodeError: label empty or too long def test_builtin_decode_length_limit(self): with self.assertRaisesRegex(UnicodeDecodeError, "way too long"): (b"xn--016c"+b"a"*1100).decode("idna") @@ -1818,7 +1811,6 @@ def test_incremental_decode(self): self.assertEqual(decoder.decode(b"rg."), "org.") self.assertEqual(decoder.decode(b"", True), "") - @unittest.expectedFailure # TODO: RUSTPYTHON; 'ascii' != 'idna' def test_incremental_decode_invalid(self): iterdecode_testcases = [ (b"\xFFpython.org", UnicodeDecodeError("idna", b"\xFF", 0, 1, "")), @@ -1880,7 +1872,6 @@ def test_incremental_encode(self): self.assertEqual(encoder.encode("ample.org."), b"xn--xample-9ta.org.") self.assertEqual(encoder.encode("", True), b"") - @unittest.expectedFailure # TODO: RUSTPYTHON; UnicodeError: label empty or too long def test_incremental_encode_invalid(self): iterencode_testcases = [ (f"foo.{'\xff'*60}", UnicodeEncodeError("idna", f"{'\xff'*60}", 0, 60, "")), diff --git a/Lib/test/test_compile.py b/Lib/test/test_compile.py index 5b07b3c85b..27bbe0b64a 100644 --- a/Lib/test/test_compile.py +++ b/Lib/test/test_compile.py @@ -885,8 +885,6 @@ def foo(x): self.assertIn('LOAD_ATTR', instructions) self.assertIn('PRECALL', instructions) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_lineno_procedure_call(self): def call(): ( diff --git a/Lib/test/test_fnmatch.py b/Lib/test/test_fnmatch.py index 092cd56285..b977b8f8eb 100644 --- a/Lib/test/test_fnmatch.py +++ b/Lib/test/test_fnmatch.py @@ -71,7 +71,7 @@ def test_fnmatchcase(self): check('usr/bin', 'usr\\bin', False, fnmatchcase) check('usr\\bin', 'usr\\bin', True, fnmatchcase) - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") + @unittest.expectedFailureIfWindows('TODO: RUSTPYTHON') def test_bytes(self): self.check_match(b'test', b'te*') self.check_match(b'test\xff', b'te*\xff') @@ -239,17 +239,9 @@ def test_translate(self): self.assertEqual(translate('A*********?[?]?'), r'(?s:A.*.[?].)\Z') # fancy translation to prevent exponential-time match failure t = translate('**a*a****a') - digits = re.findall(r'\d+', t) - self.assertEqual(len(digits), 4) - self.assertEqual(digits[0], digits[1]) - self.assertEqual(digits[2], digits[3]) - g1 = f"g{digits[0]}" # e.g., group name "g4" - g2 = f"g{digits[2]}" # e.g., group name "g5" - self.assertEqual(t, - fr'(?s:(?=(?P<{g1}>.*?a))(?P={g1})(?=(?P<{g2}>.*?a))(?P={g2}).*a)\Z') + self.assertEqual(t, r'(?s:(?>.*?a)(?>.*?a).*a)\Z') # and try pasting multiple translate results - it's an undocumented - # feature that this works; all the pain of generating unique group - # names across calls exists to support this + # feature that this works r1 = translate('**a**a**a*') r2 = translate('**b**b**b*') r3 = translate('*c*c*c*') diff --git a/Lib/test/test_fstring.py b/Lib/test/test_fstring.py index 4996eedc1c..cc9f066b14 100644 --- a/Lib/test/test_fstring.py +++ b/Lib/test/test_fstring.py @@ -9,6 +9,7 @@ import ast import datetime +import dis import os import re import types @@ -19,7 +20,7 @@ from test.support.os_helper import temp_cwd from test.support.script_helper import assert_python_failure, assert_python_ok -a_global = "global variable" +a_global = 'global variable' # You could argue that I'm too strict in looking for specific error # values with assertRaisesRegex, but without it it's way too easy to @@ -28,7 +29,6 @@ # worthwhile tradeoff. When I switched to this method, I found many # examples where I wasn't testing what I thought I was. - class TestCase(unittest.TestCase): def assertAllRaise(self, exception_type, regex, error_strings): for str in error_strings: @@ -40,45 +40,43 @@ def test__format__lookup(self): # Make sure __format__ is looked up on the type, not the instance. class X: def __format__(self, spec): - return "class" + return 'class' x = X() # Add a bound __format__ method to the 'y' instance, but not # the 'x' instance. y = X() - y.__format__ = types.MethodType(lambda self, spec: "instance", y) + y.__format__ = types.MethodType(lambda self, spec: 'instance', y) - self.assertEqual(f"{y}", format(y)) - self.assertEqual(f"{y}", "class") + self.assertEqual(f'{y}', format(y)) + self.assertEqual(f'{y}', 'class') self.assertEqual(format(x), format(y)) # __format__ is not called this way, but still make sure it # returns what we expect (so we can make sure we're bypassing # it). - self.assertEqual(x.__format__(""), "class") - self.assertEqual(y.__format__(""), "instance") + self.assertEqual(x.__format__(''), 'class') + self.assertEqual(y.__format__(''), 'instance') # This is how __format__ is actually called. - self.assertEqual(type(x).__format__(x, ""), "class") - self.assertEqual(type(y).__format__(y, ""), "class") + self.assertEqual(type(x).__format__(x, ''), 'class') + self.assertEqual(type(y).__format__(y, ''), 'class') def test_ast(self): # Inspired by http://bugs.python.org/issue24975 class X: def __init__(self): self.called = False - def __call__(self): self.called = True return 4 - x = X() expr = """ a = 10 f'{a * x()}'""" t = ast.parse(expr) - c = compile(t, "", "exec") + c = compile(t, '', 'exec') # Make sure x was not called. self.assertFalse(x.called) @@ -284,6 +282,7 @@ def test_ast_line_numbers_duplicate_expression(self): self.assertEqual(binop.right.col_offset, 27) def test_ast_numbers_fstring_with_formatting(self): + t = ast.parse('f"Here is that pesky {xxx:.3f} again"') self.assertEqual(len(t.body), 1) self.assertEqual(t.body[0].lineno, 1) @@ -384,8 +383,7 @@ def test_ast_line_numbers_multiline_fstring(self): self.assertEqual(t.body[0].value.values[1].value.col_offset, 11) self.assertEqual(t.body[0].value.values[1].value.end_col_offset, 16) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_ast_line_numbers_with_parentheses(self): expr = """ x = ( @@ -441,12 +439,24 @@ def test_ast_line_numbers_with_parentheses(self): x, y = t.body # Check the single quoted string offsets first. - offsets = [(elt.col_offset, elt.end_col_offset) for elt in x.value.elts] - self.assertTrue(all(offset == (4, 10) for offset in offsets)) + offsets = [ + (elt.col_offset, elt.end_col_offset) + for elt in x.value.elts + ] + self.assertTrue(all( + offset == (4, 10) + for offset in offsets + )) # Check the triple quoted string offsets. - offsets = [(elt.col_offset, elt.end_col_offset) for elt in y.value.elts] - self.assertTrue(all(offset == (4, 14) for offset in offsets)) + offsets = [ + (elt.col_offset, elt.end_col_offset) + for elt in y.value.elts + ] + self.assertTrue(all( + offset == (4, 14) + for offset in offsets + )) expr = """ x = ( @@ -507,612 +517,530 @@ def test_ast_fstring_empty_format_spec(self): self.assertEqual(type(format_spec), ast.JoinedStr) self.assertEqual(len(format_spec.values), 0) + def test_ast_fstring_format_spec(self): + expr = "f'{1:{name}}'" + + mod = ast.parse(expr) + self.assertEqual(type(mod), ast.Module) + self.assertEqual(len(mod.body), 1) + + fstring = mod.body[0].value + self.assertEqual(type(fstring), ast.JoinedStr) + self.assertEqual(len(fstring.values), 1) + + fv = fstring.values[0] + self.assertEqual(type(fv), ast.FormattedValue) + + format_spec = fv.format_spec + self.assertEqual(type(format_spec), ast.JoinedStr) + self.assertEqual(len(format_spec.values), 1) + + format_spec_value = format_spec.values[0] + self.assertEqual(type(format_spec_value), ast.FormattedValue) + self.assertEqual(format_spec_value.value.id, 'name') + + expr = "f'{1:{name1}{name2}}'" + + mod = ast.parse(expr) + self.assertEqual(type(mod), ast.Module) + self.assertEqual(len(mod.body), 1) + + fstring = mod.body[0].value + self.assertEqual(type(fstring), ast.JoinedStr) + self.assertEqual(len(fstring.values), 1) + + fv = fstring.values[0] + self.assertEqual(type(fv), ast.FormattedValue) + + format_spec = fv.format_spec + self.assertEqual(type(format_spec), ast.JoinedStr) + self.assertEqual(len(format_spec.values), 2) + + format_spec_value = format_spec.values[0] + self.assertEqual(type(format_spec_value), ast.FormattedValue) + self.assertEqual(format_spec_value.value.id, 'name1') + + format_spec_value = format_spec.values[1] + self.assertEqual(type(format_spec_value), ast.FormattedValue) + self.assertEqual(format_spec_value.value.id, 'name2') + + def test_docstring(self): def f(): - f"""Not a docstring""" - + f'''Not a docstring''' self.assertIsNone(f.__doc__) - def g(): - """Not a docstring""" f"" - + '''Not a docstring''' \ + f'' self.assertIsNone(g.__doc__) def test_literal_eval(self): - with self.assertRaisesRegex(ValueError, "malformed node or string"): + with self.assertRaisesRegex(ValueError, 'malformed node or string'): ast.literal_eval("f'x'") def test_ast_compile_time_concat(self): - x = [""] + x = [''] expr = """x[0] = 'foo' f'{3}'""" t = ast.parse(expr) - c = compile(t, "", "exec") + c = compile(t, '', 'exec') exec(c) - self.assertEqual(x[0], "foo3") + self.assertEqual(x[0], 'foo3') - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_compile_time_concat_errors(self): - self.assertAllRaise( - SyntaxError, - "cannot mix bytes and nonbytes literals", - [ - r"""f'' b''""", - r"""b'' f''""", - ], - ) + self.assertAllRaise(SyntaxError, + 'cannot mix bytes and nonbytes literals', + [r"""f'' b''""", + r"""b'' f''""", + ]) def test_literal(self): - self.assertEqual(f"", "") - self.assertEqual(f"a", "a") - self.assertEqual(f" ", " ") + self.assertEqual(f'', '') + self.assertEqual(f'a', 'a') + self.assertEqual(f' ', ' ') - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_unterminated_string(self): - self.assertAllRaise( - SyntaxError, - "unterminated string", - [ - r"""f'{"x'""", - r"""f'{"x}'""", - r"""f'{("x'""", - r"""f'{("x}'""", - ], - ) - - # TODO: RUSTPYTHON - @unittest.expectedFailure + self.assertAllRaise(SyntaxError, 'unterminated string', + [r"""f'{"x'""", + r"""f'{"x}'""", + r"""f'{("x'""", + r"""f'{("x}'""", + ]) + + @unittest.expectedFailure # TODO: RUSTPYTHON @unittest.skipIf(support.is_wasi, "exhausts limited stack on WASI") def test_mismatched_parens(self): - self.assertAllRaise( - SyntaxError, - r"closing parenthesis '\}' " r"does not match opening parenthesis '\('", - [ - "f'{((}'", - ], - ) - self.assertAllRaise( - SyntaxError, - r"closing parenthesis '\)' " r"does not match opening parenthesis '\['", - [ - "f'{a[4)}'", - ], - ) - self.assertAllRaise( - SyntaxError, - r"closing parenthesis '\]' " r"does not match opening parenthesis '\('", - [ - "f'{a(4]}'", - ], - ) - self.assertAllRaise( - SyntaxError, - r"closing parenthesis '\}' " r"does not match opening parenthesis '\['", - [ - "f'{a[4}'", - ], - ) - self.assertAllRaise( - SyntaxError, - r"closing parenthesis '\}' " r"does not match opening parenthesis '\('", - [ - "f'{a(4}'", - ], - ) - self.assertRaises(SyntaxError, eval, "f'{" + "(" * 500 + "}'") - - # TODO: RUSTPYTHON - @unittest.expectedFailure + self.assertAllRaise(SyntaxError, r"closing parenthesis '\}' " + r"does not match opening parenthesis '\('", + ["f'{((}'", + ]) + self.assertAllRaise(SyntaxError, r"closing parenthesis '\)' " + r"does not match opening parenthesis '\['", + ["f'{a[4)}'", + ]) + self.assertAllRaise(SyntaxError, r"closing parenthesis '\]' " + r"does not match opening parenthesis '\('", + ["f'{a(4]}'", + ]) + self.assertAllRaise(SyntaxError, r"closing parenthesis '\}' " + r"does not match opening parenthesis '\['", + ["f'{a[4}'", + ]) + self.assertAllRaise(SyntaxError, r"closing parenthesis '\}' " + r"does not match opening parenthesis '\('", + ["f'{a(4}'", + ]) + self.assertRaises(SyntaxError, eval, "f'{" + "("*500 + "}'") + + @unittest.expectedFailure # TODO: RUSTPYTHON @unittest.skipIf(support.is_wasi, "exhausts limited stack on WASI") def test_fstring_nested_too_deeply(self): - self.assertAllRaise( - SyntaxError, - "f-string: expressions nested too deeply", - ['f"{1+2:{1+2:{1+1:{1}}}}"'], - ) + self.assertAllRaise(SyntaxError, + "f-string: expressions nested too deeply", + ['f"{1+2:{1+2:{1+1:{1}}}}"']) def create_nested_fstring(n): if n == 0: return "1+1" - prev = create_nested_fstring(n - 1) + prev = create_nested_fstring(n-1) return f'f"{{{prev}}}"' - self.assertAllRaise( - SyntaxError, "too many nested f-strings", [create_nested_fstring(160)] - ) + self.assertAllRaise(SyntaxError, + "too many nested f-strings", + [create_nested_fstring(160)]) def test_syntax_error_in_nested_fstring(self): # See gh-104016 for more information on this crash - self.assertAllRaise( - SyntaxError, "invalid syntax", ['f"{1 1:' + ('{f"1:' * 199)] - ) + self.assertAllRaise(SyntaxError, + "invalid syntax", + ['f"{1 1:' + ('{f"1:' * 199)]) def test_double_braces(self): - self.assertEqual(f"{{", "{") - self.assertEqual(f"a{{", "a{") - self.assertEqual(f"{{b", "{b") - self.assertEqual(f"a{{b", "a{b") - self.assertEqual(f"}}", "}") - self.assertEqual(f"a}}", "a}") - self.assertEqual(f"}}b", "}b") - self.assertEqual(f"a}}b", "a}b") - self.assertEqual(f"{{}}", "{}") - self.assertEqual(f"a{{}}", "a{}") - self.assertEqual(f"{{b}}", "{b}") - self.assertEqual(f"{{}}c", "{}c") - self.assertEqual(f"a{{b}}", "a{b}") - self.assertEqual(f"a{{}}c", "a{}c") - self.assertEqual(f"{{b}}c", "{b}c") - self.assertEqual(f"a{{b}}c", "a{b}c") - - self.assertEqual(f"{{{10}", "{10") - self.assertEqual(f"}}{10}", "}10") - self.assertEqual(f"}}{{{10}", "}{10") - self.assertEqual(f"}}a{{{10}", "}a{10") - - self.assertEqual(f"{10}{{", "10{") - self.assertEqual(f"{10}}}", "10}") - self.assertEqual(f"{10}}}{{", "10}{") - self.assertEqual(f"{10}}}a{{" "}", "10}a{}") + self.assertEqual(f'{{', '{') + self.assertEqual(f'a{{', 'a{') + self.assertEqual(f'{{b', '{b') + self.assertEqual(f'a{{b', 'a{b') + self.assertEqual(f'}}', '}') + self.assertEqual(f'a}}', 'a}') + self.assertEqual(f'}}b', '}b') + self.assertEqual(f'a}}b', 'a}b') + self.assertEqual(f'{{}}', '{}') + self.assertEqual(f'a{{}}', 'a{}') + self.assertEqual(f'{{b}}', '{b}') + self.assertEqual(f'{{}}c', '{}c') + self.assertEqual(f'a{{b}}', 'a{b}') + self.assertEqual(f'a{{}}c', 'a{}c') + self.assertEqual(f'{{b}}c', '{b}c') + self.assertEqual(f'a{{b}}c', 'a{b}c') + + self.assertEqual(f'{{{10}', '{10') + self.assertEqual(f'}}{10}', '}10') + self.assertEqual(f'}}{{{10}', '}{10') + self.assertEqual(f'}}a{{{10}', '}a{10') + + self.assertEqual(f'{10}{{', '10{') + self.assertEqual(f'{10}}}', '10}') + self.assertEqual(f'{10}}}{{', '10}{') + self.assertEqual(f'{10}}}a{{' '}', '10}a{}') # Inside of strings, don't interpret doubled brackets. - self.assertEqual(f'{"{{}}"}', "{{}}") - - self.assertAllRaise( - TypeError, - "unhashable type", - [ - "f'{ {{}} }'", # dict in a set - ], - ) - - # TODO: RUSTPYTHON - @unittest.expectedFailure + self.assertEqual(f'{"{{}}"}', '{{}}') + + self.assertAllRaise(TypeError, 'unhashable type', + ["f'{ {{}} }'", # dict in a set + ]) + + @unittest.expectedFailure # TODO: RUSTPYTHON def test_compile_time_concat(self): - x = "def" - self.assertEqual("abc" f"## {x}ghi", "abc## defghi") - self.assertEqual("abc" f"{x}" "ghi", "abcdefghi") - self.assertEqual("abc" f"{x}" "gh" f"i{x:4}", "abcdefghidef ") - self.assertEqual("{x}" f"{x}", "{x}def") - self.assertEqual("{x" f"{x}", "{xdef") - self.assertEqual("{x}" f"{x}", "{x}def") - self.assertEqual("{{x}}" f"{x}", "{{x}}def") - self.assertEqual("{{x" f"{x}", "{{xdef") - self.assertEqual("x}}" f"{x}", "x}}def") - self.assertEqual(f"{x}" "x}}", "defx}}") - self.assertEqual(f"{x}" "", "def") - self.assertEqual("" f"{x}" "", "def") - self.assertEqual("" f"{x}", "def") - self.assertEqual(f"{x}" "2", "def2") - self.assertEqual("1" f"{x}" "2", "1def2") - self.assertEqual("1" f"{x}", "1def") - self.assertEqual(f"{x}" f"-{x}", "def-def") - self.assertEqual("" f"", "") - self.assertEqual("" f"" "", "") - self.assertEqual("" f"" "" f"", "") - self.assertEqual(f"", "") - self.assertEqual(f"" "", "") - self.assertEqual(f"" "" f"", "") - self.assertEqual(f"" "" f"" "", "") + x = 'def' + self.assertEqual('abc' f'## {x}ghi', 'abc## defghi') + self.assertEqual('abc' f'{x}' 'ghi', 'abcdefghi') + self.assertEqual('abc' f'{x}' 'gh' f'i{x:4}', 'abcdefghidef ') + self.assertEqual('{x}' f'{x}', '{x}def') + self.assertEqual('{x' f'{x}', '{xdef') + self.assertEqual('{x}' f'{x}', '{x}def') + self.assertEqual('{{x}}' f'{x}', '{{x}}def') + self.assertEqual('{{x' f'{x}', '{{xdef') + self.assertEqual('x}}' f'{x}', 'x}}def') + self.assertEqual(f'{x}' 'x}}', 'defx}}') + self.assertEqual(f'{x}' '', 'def') + self.assertEqual('' f'{x}' '', 'def') + self.assertEqual('' f'{x}', 'def') + self.assertEqual(f'{x}' '2', 'def2') + self.assertEqual('1' f'{x}' '2', '1def2') + self.assertEqual('1' f'{x}', '1def') + self.assertEqual(f'{x}' f'-{x}', 'def-def') + self.assertEqual('' f'', '') + self.assertEqual('' f'' '', '') + self.assertEqual('' f'' '' f'', '') + self.assertEqual(f'', '') + self.assertEqual(f'' '', '') + self.assertEqual(f'' '' f'', '') + self.assertEqual(f'' '' f'' '', '') # This is not really [f'{'] + [f'}'] since we treat the inside # of braces as a purely new context, so it is actually f'{ and # then eval(' f') (a valid expression) and then }' which would # constitute a valid f-string. - # TODO: RUSTPYTHON SyntaxError - # self.assertEqual(f'{' f'}', " f") - - self.assertAllRaise( - SyntaxError, - "expecting '}'", - [ - '''f'{3' f"}"''', # can't concat to get a valid f-string - ], - ) - - # TODO: RUSTPYTHON - @unittest.expectedFailure + self.assertEqual(f'{' f'}', ' f') + + self.assertAllRaise(SyntaxError, "expecting '}'", + ['''f'{3' f"}"''', # can't concat to get a valid f-string + ]) + + @unittest.expectedFailure # TODO: RUSTPYTHON def test_comments(self): # These aren't comments, since they're in strings. - d = {"#": "hash"} - self.assertEqual(f'{"#"}', "#") - self.assertEqual(f'{d["#"]}', "hash") - - self.assertAllRaise( - SyntaxError, - "'{' was never closed", - [ - "f'{1#}'", # error because everything after '#' is a comment - "f'{#}'", - "f'one: {1#}'", - "f'{1# one} {2 this is a comment still#}'", - ], - ) - self.assertAllRaise( - SyntaxError, - r"f-string: unmatched '\)'", - [ - "f'{)#}'", # When wrapped in parens, this becomes - # '()#)'. Make sure that doesn't compile. - ], - ) - self.assertEqual( - f"""A complex trick: { + d = {'#': 'hash'} + self.assertEqual(f'{"#"}', '#') + self.assertEqual(f'{d["#"]}', 'hash') + + self.assertAllRaise(SyntaxError, "'{' was never closed", + ["f'{1#}'", # error because everything after '#' is a comment + "f'{#}'", + "f'one: {1#}'", + "f'{1# one} {2 this is a comment still#}'", + ]) + self.assertAllRaise(SyntaxError, r"f-string: unmatched '\)'", + ["f'{)#}'", # When wrapped in parens, this becomes + # '()#)'. Make sure that doesn't compile. + ]) + self.assertEqual(f'''A complex trick: { 2 # two -}""", - "A complex trick: 2", - ) - self.assertEqual( - f""" +}''', 'A complex trick: 2') + self.assertEqual(f''' { 40 # forty + # plus 2 # two -}""", - "\n42", - ) - self.assertEqual( - f""" +}''', '\n42') + self.assertEqual(f''' { 40 # forty + # plus 2 # two -}""", - "\n42", - ) -# TODO: RUSTPYTHON SyntaxError -# self.assertEqual( -# f""" -# # this is not a comment -# { # the following operation it's -# 3 # this is a number -# * 2}""", -# "\n# this is not a comment\n6", -# ) - self.assertEqual( - f""" +}''', '\n42') + + self.assertEqual(f''' +# this is not a comment +{ # the following operation it's +3 # this is a number +* 2}''', '\n# this is not a comment\n6') + self.assertEqual(f''' {# f'a {comment}' 86 # constant # nothing more -}""", - "\n86", - ) - - self.assertAllRaise( - SyntaxError, - r"f-string: valid expression required before '}'", - [ - """f''' +}''', '\n86') + + self.assertAllRaise(SyntaxError, r"f-string: valid expression required before '}'", + ["""f''' { # only a comment }''' -""", # this is equivalent to f'{}' - ], - ) +""", # this is equivalent to f'{}' + ]) def test_many_expressions(self): # Create a string with many expressions in it. Note that # because we have a space in here as a literal, we're actually # going to use twice as many ast nodes: one for each literal # plus one for each expression. - def build_fstr(n, extra=""): - return "f'" + ("{x} " * n) + extra + "'" + def build_fstr(n, extra=''): + return "f'" + ('{x} ' * n) + extra + "'" - x = "X" + x = 'X' width = 1 # Test around 256. for i in range(250, 260): - self.assertEqual(eval(build_fstr(i)), (x + " ") * i) + self.assertEqual(eval(build_fstr(i)), (x+' ')*i) # Test concatenating 2 largs fstrings. - self.assertEqual(eval(build_fstr(255) * 256), (x + " ") * (255 * 256)) + self.assertEqual(eval(build_fstr(255)*256), (x+' ')*(255*256)) - s = build_fstr(253, "{x:{width}} ") - self.assertEqual(eval(s), (x + " ") * 254) + s = build_fstr(253, '{x:{width}} ') + self.assertEqual(eval(s), (x+' ')*254) # Test lots of expressions and constants, concatenated. s = "f'{1}' 'x' 'y'" * 1024 - self.assertEqual(eval(s), "1xy" * 1024) + self.assertEqual(eval(s), '1xy' * 1024) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_format_specifier_expressions(self): width = 10 precision = 4 - value = decimal.Decimal("12.34567") - self.assertEqual(f"result: {value:{width}.{precision}}", "result: 12.35") - self.assertEqual(f"result: {value:{width!r}.{precision}}", "result: 12.35") - self.assertEqual( - f"result: {value:{width:0}.{precision:1}}", "result: 12.35" - ) - self.assertEqual( - f"result: {value:{1}{0:0}.{precision:1}}", "result: 12.35" - ) - self.assertEqual( - f"result: {value:{ 1}{ 0:0}.{ precision:1}}", "result: 12.35" - ) - self.assertEqual(f"{10:#{1}0x}", " 0xa") - self.assertEqual(f'{10:{"#"}1{0}{"x"}}', " 0xa") - self.assertEqual(f'{-10:-{"#"}1{0}x}', " -0xa") - self.assertEqual(f'{-10:{"-"}#{1}0{"x"}}', " -0xa") - self.assertEqual(f"{10:#{3 != {4:5} and width}x}", " 0xa") - - # TODO: RUSTPYTHON SyntaxError - # self.assertEqual( - # f"result: {value:{width:{0}}.{precision:1}}", "result: 12.35" - # ) - - - self.assertAllRaise( - SyntaxError, - "f-string: expecting ':' or '}'", - [ - """f'{"s"!r{":10"}}'""", - # This looks like a nested format spec. - ], - ) - - - self.assertAllRaise( - SyntaxError, - "f-string: expecting a valid expression after '{'", - [ # Invalid syntax inside a nested spec. - "f'{4:{/5}}'", - ], - ) - - self.assertAllRaise( - SyntaxError, - "f-string: invalid conversion character", - [ # No expansion inside conversion or for - # the : or ! itself. - """f'{"s"!{"r"}}'""", - ], - ) - - # TODO: RUSTPYTHON - @unittest.expectedFailure + value = decimal.Decimal('12.34567') + self.assertEqual(f'result: {value:{width}.{precision}}', 'result: 12.35') + self.assertEqual(f'result: {value:{width!r}.{precision}}', 'result: 12.35') + self.assertEqual(f'result: {value:{width:0}.{precision:1}}', 'result: 12.35') + self.assertEqual(f'result: {value:{1}{0:0}.{precision:1}}', 'result: 12.35') + self.assertEqual(f'result: {value:{ 1}{ 0:0}.{ precision:1}}', 'result: 12.35') + self.assertEqual(f'{10:#{1}0x}', ' 0xa') + self.assertEqual(f'{10:{"#"}1{0}{"x"}}', ' 0xa') + self.assertEqual(f'{-10:-{"#"}1{0}x}', ' -0xa') + self.assertEqual(f'{-10:{"-"}#{1}0{"x"}}', ' -0xa') + self.assertEqual(f'{10:#{3 != {4:5} and width}x}', ' 0xa') + self.assertEqual(f'result: {value:{width:{0}}.{precision:1}}', 'result: 12.35') + + self.assertAllRaise(SyntaxError, "f-string: expecting ':' or '}'", + ["""f'{"s"!r{":10"}}'""", + # This looks like a nested format spec. + ]) + + self.assertAllRaise(SyntaxError, + "f-string: expecting a valid expression after '{'", + [# Invalid syntax inside a nested spec. + "f'{4:{/5}}'", + ]) + + self.assertAllRaise(SyntaxError, 'f-string: invalid conversion character', + [# No expansion inside conversion or for + # the : or ! itself. + """f'{"s"!{"r"}}'""", + ]) + + @unittest.expectedFailure # TODO: RUSTPYTHON def test_custom_format_specifier(self): class CustomFormat: def __format__(self, format_spec): return format_spec - self.assertEqual(f"{CustomFormat():\n}", "\n") - self.assertEqual(f"{CustomFormat():\u2603}", "☃") + self.assertEqual(f'{CustomFormat():\n}', '\n') + self.assertEqual(f'{CustomFormat():\u2603}', '☃') with self.assertWarns(SyntaxWarning): - exec(r'f"{F():¯\_(ツ)_/¯}"', {"F": CustomFormat}) + exec(r'f"{F():¯\_(ツ)_/¯}"', {'F': CustomFormat}) def test_side_effect_order(self): class X: def __init__(self): self.i = 0 - def __format__(self, spec): self.i += 1 return str(self.i) x = X() - self.assertEqual(f"{x} {x}", "1 2") + self.assertEqual(f'{x} {x}', '1 2') - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_missing_expression(self): - self.assertAllRaise( - SyntaxError, - "f-string: valid expression required before '}'", - [ - "f'{}'", - "f'{ }'" "f' {} '", - "f'{10:{ }}'", - "f' { } '", - # The Python parser ignores also the following - # whitespace characters in additional to a space. - "f'''{\t\f\r\n}'''", - ], - ) - - self.assertAllRaise( - SyntaxError, - "f-string: valid expression required before '!'", - [ - "f'{!r}'", - "f'{ !r}'", - "f'{!}'", - "f'''{\t\f\r\n!a}'''", - # Catch empty expression before the - # missing closing brace. - "f'{!'", - "f'{!s:'", - # Catch empty expression before the - # invalid conversion. - "f'{!x}'", - "f'{ !xr}'", - "f'{!x:}'", - "f'{!x:a}'", - "f'{ !xr:}'", - "f'{ !xr:a}'", - ], - ) - - self.assertAllRaise( - SyntaxError, - "f-string: valid expression required before ':'", - [ - "f'{:}'", - "f'{ :!}'", - "f'{:2}'", - "f'''{\t\f\r\n:a}'''", - "f'{:'", - "F'{[F'{:'}[F'{:'}]]]", - ], - ) - - self.assertAllRaise( - SyntaxError, - "f-string: valid expression required before '='", - [ - "f'{=}'", - "f'{ =}'", - "f'{ =:}'", - "f'{ =!}'", - "f'''{\t\f\r\n=}'''", - "f'{='", - ], - ) + self.assertAllRaise(SyntaxError, + "f-string: valid expression required before '}'", + ["f'{}'", + "f'{ }'" + "f' {} '", + "f'{10:{ }}'", + "f' { } '", + + # The Python parser ignores also the following + # whitespace characters in additional to a space. + "f'''{\t\f\r\n}'''", + ]) + + self.assertAllRaise(SyntaxError, + "f-string: valid expression required before '!'", + ["f'{!r}'", + "f'{ !r}'", + "f'{!}'", + "f'''{\t\f\r\n!a}'''", + + # Catch empty expression before the + # missing closing brace. + "f'{!'", + "f'{!s:'", + + # Catch empty expression before the + # invalid conversion. + "f'{!x}'", + "f'{ !xr}'", + "f'{!x:}'", + "f'{!x:a}'", + "f'{ !xr:}'", + "f'{ !xr:a}'", + ]) + + self.assertAllRaise(SyntaxError, + "f-string: valid expression required before ':'", + ["f'{:}'", + "f'{ :!}'", + "f'{:2}'", + "f'''{\t\f\r\n:a}'''", + "f'{:'", + "F'{[F'{:'}[F'{:'}]]]", + ]) + + self.assertAllRaise(SyntaxError, + "f-string: valid expression required before '='", + ["f'{=}'", + "f'{ =}'", + "f'{ =:}'", + "f'{ =!}'", + "f'''{\t\f\r\n=}'''", + "f'{='", + ]) # Different error message is raised for other whitespace characters. - self.assertAllRaise( - SyntaxError, - r"invalid non-printable character U\+00A0", - [ - "f'''{\xa0}'''", - "\xa0", - ], - ) - - # TODO: RUSTPYTHON - @unittest.expectedFailure + self.assertAllRaise(SyntaxError, r"invalid non-printable character U\+00A0", + ["f'''{\xa0}'''", + "\xa0", + ]) + + @unittest.expectedFailure # TODO: RUSTPYTHON def test_parens_in_expressions(self): - self.assertEqual(f"{3,}", "(3,)") - - self.assertAllRaise( - SyntaxError, - "f-string: expecting a valid expression after '{'", - [ - "f'{,}'", - ], - ) - - self.assertAllRaise( - SyntaxError, - r"f-string: unmatched '\)'", - [ - "f'{3)+(4}'", - ], - ) - - # TODO: RUSTPYTHON - @unittest.expectedFailure + self.assertEqual(f'{3,}', '(3,)') + + self.assertAllRaise(SyntaxError, + "f-string: expecting a valid expression after '{'", + ["f'{,}'", + ]) + + self.assertAllRaise(SyntaxError, r"f-string: unmatched '\)'", + ["f'{3)+(4}'", + ]) + + @unittest.expectedFailure # TODO: RUSTPYTHON def test_newlines_before_syntax_error(self): - self.assertAllRaise( - SyntaxError, - "f-string: expecting a valid expression after '{'", - ["f'{.}'", "\nf'{.}'", "\n\nf'{.}'"], - ) - - # TODO: RUSTPYTHON - @unittest.expectedFailure + self.assertAllRaise(SyntaxError, + "f-string: expecting a valid expression after '{'", + ["f'{.}'", "\nf'{.}'", "\n\nf'{.}'"]) + + @unittest.expectedFailure # TODO: RUSTPYTHON def test_backslashes_in_string_part(self): - self.assertEqual(f"\t", "\t") - self.assertEqual(r"\t", "\\t") - self.assertEqual(rf"\t", "\\t") - self.assertEqual(f"{2}\t", "2\t") - self.assertEqual(f"{2}\t{3}", "2\t3") - self.assertEqual(f"\t{3}", "\t3") - - self.assertEqual(f"\u0394", "\u0394") - self.assertEqual(r"\u0394", "\\u0394") - self.assertEqual(rf"\u0394", "\\u0394") - self.assertEqual(f"{2}\u0394", "2\u0394") - self.assertEqual(f"{2}\u0394{3}", "2\u03943") - self.assertEqual(f"\u0394{3}", "\u03943") - - self.assertEqual(f"\U00000394", "\u0394") - self.assertEqual(r"\U00000394", "\\U00000394") - self.assertEqual(rf"\U00000394", "\\U00000394") - self.assertEqual(f"{2}\U00000394", "2\u0394") - self.assertEqual(f"{2}\U00000394{3}", "2\u03943") - self.assertEqual(f"\U00000394{3}", "\u03943") - - self.assertEqual(f"\N{GREEK CAPITAL LETTER DELTA}", "\u0394") - self.assertEqual(f"{2}\N{GREEK CAPITAL LETTER DELTA}", "2\u0394") - self.assertEqual(f"{2}\N{GREEK CAPITAL LETTER DELTA}{3}", "2\u03943") - self.assertEqual(f"\N{GREEK CAPITAL LETTER DELTA}{3}", "\u03943") - self.assertEqual(f"2\N{GREEK CAPITAL LETTER DELTA}", "2\u0394") - self.assertEqual(f"2\N{GREEK CAPITAL LETTER DELTA}3", "2\u03943") - self.assertEqual(f"\N{GREEK CAPITAL LETTER DELTA}3", "\u03943") - - self.assertEqual(f"\x20", " ") - self.assertEqual(r"\x20", "\\x20") - self.assertEqual(rf"\x20", "\\x20") - self.assertEqual(f"{2}\x20", "2 ") - self.assertEqual(f"{2}\x20{3}", "2 3") - self.assertEqual(f"\x20{3}", " 3") - - self.assertEqual(f"2\x20", "2 ") - self.assertEqual(f"2\x203", "2 3") - self.assertEqual(f"\x203", " 3") + self.assertEqual(f'\t', '\t') + self.assertEqual(r'\t', '\\t') + self.assertEqual(rf'\t', '\\t') + self.assertEqual(f'{2}\t', '2\t') + self.assertEqual(f'{2}\t{3}', '2\t3') + self.assertEqual(f'\t{3}', '\t3') + + self.assertEqual(f'\u0394', '\u0394') + self.assertEqual(r'\u0394', '\\u0394') + self.assertEqual(rf'\u0394', '\\u0394') + self.assertEqual(f'{2}\u0394', '2\u0394') + self.assertEqual(f'{2}\u0394{3}', '2\u03943') + self.assertEqual(f'\u0394{3}', '\u03943') + + self.assertEqual(f'\U00000394', '\u0394') + self.assertEqual(r'\U00000394', '\\U00000394') + self.assertEqual(rf'\U00000394', '\\U00000394') + self.assertEqual(f'{2}\U00000394', '2\u0394') + self.assertEqual(f'{2}\U00000394{3}', '2\u03943') + self.assertEqual(f'\U00000394{3}', '\u03943') + + self.assertEqual(f'\N{GREEK CAPITAL LETTER DELTA}', '\u0394') + self.assertEqual(f'{2}\N{GREEK CAPITAL LETTER DELTA}', '2\u0394') + self.assertEqual(f'{2}\N{GREEK CAPITAL LETTER DELTA}{3}', '2\u03943') + self.assertEqual(f'\N{GREEK CAPITAL LETTER DELTA}{3}', '\u03943') + self.assertEqual(f'2\N{GREEK CAPITAL LETTER DELTA}', '2\u0394') + self.assertEqual(f'2\N{GREEK CAPITAL LETTER DELTA}3', '2\u03943') + self.assertEqual(f'\N{GREEK CAPITAL LETTER DELTA}3', '\u03943') + + self.assertEqual(f'\x20', ' ') + self.assertEqual(r'\x20', '\\x20') + self.assertEqual(rf'\x20', '\\x20') + self.assertEqual(f'{2}\x20', '2 ') + self.assertEqual(f'{2}\x20{3}', '2 3') + self.assertEqual(f'\x20{3}', ' 3') + + self.assertEqual(f'2\x20', '2 ') + self.assertEqual(f'2\x203', '2 3') + self.assertEqual(f'\x203', ' 3') with self.assertWarns(SyntaxWarning): # invalid escape sequence value = eval(r"f'\{6*7}'") - self.assertEqual(value, "\\42") + self.assertEqual(value, '\\42') with self.assertWarns(SyntaxWarning): # invalid escape sequence value = eval(r"f'\g'") - self.assertEqual(value, "\\g") - self.assertEqual(f"\\{6*7}", "\\42") - self.assertEqual(rf"\{6*7}", "\\42") + self.assertEqual(value, '\\g') + self.assertEqual(f'\\{6*7}', '\\42') + self.assertEqual(fr'\{6*7}', '\\42') - AMPERSAND = "spam" + AMPERSAND = 'spam' # Get the right unicode character (&), or pick up local variable # depending on the number of backslashes. - self.assertEqual(f"\N{AMPERSAND}", "&") - self.assertEqual(f"\\N{AMPERSAND}", "\\Nspam") - self.assertEqual(rf"\N{AMPERSAND}", "\\Nspam") - self.assertEqual(f"\\\N{AMPERSAND}", "\\&") + self.assertEqual(f'\N{AMPERSAND}', '&') + self.assertEqual(f'\\N{AMPERSAND}', '\\Nspam') + self.assertEqual(fr'\N{AMPERSAND}', '\\Nspam') + self.assertEqual(f'\\\N{AMPERSAND}', '\\&') - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_misformed_unicode_character_name(self): # These test are needed because unicode names are parsed # differently inside f-strings. - self.assertAllRaise( - SyntaxError, - r"\(unicode error\) 'unicodeescape' codec can't decode bytes in position .*: malformed \\N character escape", - [ - r"f'\N'", - r"f'\N '", - r"f'\N '", # See bpo-46503. - r"f'\N{'", - r"f'\N{GREEK CAPITAL LETTER DELTA'", - # Here are the non-f-string versions, - # which should give the same errors. - r"'\N'", - r"'\N '", - r"'\N '", - r"'\N{'", - r"'\N{GREEK CAPITAL LETTER DELTA'", - ], - ) - - # TODO: RUSTPYTHON - @unittest.expectedFailure + self.assertAllRaise(SyntaxError, r"\(unicode error\) 'unicodeescape' codec can't decode bytes in position .*: malformed \\N character escape", + [r"f'\N'", + r"f'\N '", + r"f'\N '", # See bpo-46503. + r"f'\N{'", + r"f'\N{GREEK CAPITAL LETTER DELTA'", + + # Here are the non-f-string versions, + # which should give the same errors. + r"'\N'", + r"'\N '", + r"'\N '", + r"'\N{'", + r"'\N{GREEK CAPITAL LETTER DELTA'", + ]) + + @unittest.expectedFailure # TODO: RUSTPYTHON def test_backslashes_in_expression_part(self): - # TODO: RUSTPYTHON SyntaxError - # self.assertEqual( - # f"{( - # 1 + - # 2 - # )}", - # "3", - # ) - - self.assertEqual("\N{LEFT CURLY BRACKET}", "{") - self.assertEqual(f'{"\N{LEFT CURLY BRACKET}"}', "{") - self.assertEqual(rf'{"\N{LEFT CURLY BRACKET}"}', "{") - - self.assertAllRaise( - SyntaxError, - "f-string: valid expression required before '}'", - [ - "f'{\n}'", - ], - ) - - # TODO: RUSTPYTHON - @unittest.expectedFailure + self.assertEqual(f"{( + 1 + + 2 + )}", "3") + + self.assertEqual("\N{LEFT CURLY BRACKET}", '{') + self.assertEqual(f'{"\N{LEFT CURLY BRACKET}"}', '{') + self.assertEqual(rf'{"\N{LEFT CURLY BRACKET}"}', '{') + + self.assertAllRaise(SyntaxError, + "f-string: valid expression required before '}'", + ["f'{\n}'", + ]) + + @unittest.expectedFailure # TODO: RUSTPYTHON def test_invalid_backslashes_inside_fstring_context(self): # All of these variations are invalid python syntax, # so they are also invalid in f-strings as well. @@ -1129,30 +1057,25 @@ def test_invalid_backslashes_inside_fstring_context(self): r"\\"[0], ] ] - self.assertAllRaise( - SyntaxError, "unexpected character after line continuation", cases - ) + self.assertAllRaise(SyntaxError, 'unexpected character after line continuation', + cases) def test_no_escapes_for_braces(self): """ Only literal curly braces begin an expression. """ # \x7b is '{'. - self.assertEqual(f"\x7b1+1}}", "{1+1}") - self.assertEqual(f"\x7b1+1", "{1+1") - self.assertEqual(f"\u007b1+1", "{1+1") - self.assertEqual(f"\N{LEFT CURLY BRACKET}1+1\N{RIGHT CURLY BRACKET}", "{1+1}") + self.assertEqual(f'\x7b1+1}}', '{1+1}') + self.assertEqual(f'\x7b1+1', '{1+1') + self.assertEqual(f'\u007b1+1', '{1+1') + self.assertEqual(f'\N{LEFT CURLY BRACKET}1+1\N{RIGHT CURLY BRACKET}', '{1+1}') def test_newlines_in_expressions(self): - self.assertEqual(f"{0}", "0") - self.assertEqual( - rf"""{3+ -4}""", - "7", - ) - - # TODO: RUSTPYTHON - @unittest.expectedFailure + self.assertEqual(f'{0}', '0') + self.assertEqual(rf'''{3+ +4}''', '7') + + @unittest.expectedFailure # TODO: RUSTPYTHON def test_lambda(self): x = 5 self.assertEqual(f'{(lambda y:x*y)("8")!r}', "'88888'") @@ -1162,67 +1085,60 @@ def test_lambda(self): # lambda doesn't work without parens, because the colon # makes the parser think it's a format_spec # emit warning if we can match a format_spec - self.assertAllRaise( - SyntaxError, - "f-string: lambda expressions are not allowed " "without parentheses", - [ - "f'{lambda x:x}'", - "f'{lambda :x}'", - "f'{lambda *arg, :x}'", - "f'{1, lambda:x}'", - "f'{lambda x:}'", - "f'{lambda :}'", - ], - ) + self.assertAllRaise(SyntaxError, + "f-string: lambda expressions are not allowed " + "without parentheses", + ["f'{lambda x:x}'", + "f'{lambda :x}'", + "f'{lambda *arg, :x}'", + "f'{1, lambda:x}'", + "f'{lambda x:}'", + "f'{lambda :}'", + ]) # Ensure the detection of invalid lambdas doesn't trigger detection # for valid lambdas in the second error pass with self.assertRaisesRegex(SyntaxError, "invalid syntax"): compile("lambda name_3=f'{name_4}': {name_3}\n1 $ 1", "", "exec") # but don't emit the paren warning in general cases - with self.assertRaisesRegex( - SyntaxError, "f-string: expecting a valid expression after '{'" - ): + with self.assertRaisesRegex(SyntaxError, "f-string: expecting a valid expression after '{'"): eval("f'{+ lambda:None}'") def test_valid_prefixes(self): - self.assertEqual(f"{1}", "1") - self.assertEqual(Rf"{2}", "2") - self.assertEqual(Rf"{3}", "3") + self.assertEqual(F'{1}', "1") + self.assertEqual(FR'{2}', "2") + self.assertEqual(fR'{3}', "3") def test_roundtrip_raw_quotes(self): - self.assertEqual(rf"\'", "\\'") - self.assertEqual(rf"\"", '\\"') - self.assertEqual(rf"\"\'", "\\\"\\'") - self.assertEqual(rf"\'\"", "\\'\\\"") - self.assertEqual(rf"\"\'\"", '\\"\\\'\\"') - self.assertEqual(rf"\'\"\'", "\\'\\\"\\'") - self.assertEqual(rf"\"\'\"\'", "\\\"\\'\\\"\\'") - - # TODO: RUSTPYTHON - @unittest.expectedFailure + self.assertEqual(fr"\'", "\\'") + self.assertEqual(fr'\"', '\\"') + self.assertEqual(fr'\"\'', '\\"\\\'') + self.assertEqual(fr'\'\"', '\\\'\\"') + self.assertEqual(fr'\"\'\"', '\\"\\\'\\"') + self.assertEqual(fr'\'\"\'', '\\\'\\"\\\'') + self.assertEqual(fr'\"\'\"\'', '\\"\\\'\\"\\\'') + + @unittest.expectedFailure # TODO: RUSTPYTHON def test_fstring_backslash_before_double_bracket(self): deprecated_cases = [ - (r"f'\{{\}}'", "\\{\\}"), - (r"f'\{{'", "\\{"), - (r"f'\{{{1+1}'", "\\{2"), - (r"f'\}}{1+1}'", "\\}2"), - (r"f'{1+1}\}}'", "2\\}"), + (r"f'\{{\}}'", '\\{\\}'), + (r"f'\{{'", '\\{'), + (r"f'\{{{1+1}'", '\\{2'), + (r"f'\}}{1+1}'", '\\}2'), + (r"f'{1+1}\}}'", '2\\}') ] - for case, expected_result in deprecated_cases: with self.subTest(case=case, expected_result=expected_result): with self.assertWarns(SyntaxWarning): result = eval(case) self.assertEqual(result, expected_result) - self.assertEqual(rf"\{{\}}", "\\{\\}") - self.assertEqual(rf"\{{", "\\{") - self.assertEqual(rf"\{{{1+1}", "\\{2") - self.assertEqual(rf"\}}{1+1}", "\\}2") - self.assertEqual(rf"{1+1}\}}", "2\\}") - - # TODO: RUSTPYTHON - @unittest.expectedFailure + self.assertEqual(fr'\{{\}}', '\\{\\}') + self.assertEqual(fr'\{{', '\\{') + self.assertEqual(fr'\{{{1+1}', '\\{2') + self.assertEqual(fr'\}}{1+1}', '\\}2') + self.assertEqual(fr'{1+1}\}}', '2\\}') + + @unittest.expectedFailure # TODO: RUSTPYTHON def test_fstring_backslash_before_double_bracket_warns_once(self): with self.assertWarns(SyntaxWarning) as w: eval(r"f'\{{'") @@ -1230,18 +1146,18 @@ def test_fstring_backslash_before_double_bracket_warns_once(self): self.assertEqual(w.warnings[0].category, SyntaxWarning) def test_fstring_backslash_prefix_raw(self): - self.assertEqual(f"\\", "\\") - self.assertEqual(f"\\\\", "\\\\") - self.assertEqual(rf"\\", r"\\") - self.assertEqual(rf"\\\\", r"\\\\") - self.assertEqual(rf"\\", r"\\") - self.assertEqual(rf"\\\\", r"\\\\") - self.assertEqual(Rf"\\", R"\\") - self.assertEqual(Rf"\\\\", R"\\\\") - self.assertEqual(Rf"\\", R"\\") - self.assertEqual(Rf"\\\\", R"\\\\") - self.assertEqual(Rf"\\", R"\\") - self.assertEqual(Rf"\\\\", R"\\\\") + self.assertEqual(f'\\', '\\') + self.assertEqual(f'\\\\', '\\\\') + self.assertEqual(fr'\\', r'\\') + self.assertEqual(fr'\\\\', r'\\\\') + self.assertEqual(rf'\\', r'\\') + self.assertEqual(rf'\\\\', r'\\\\') + self.assertEqual(Rf'\\', R'\\') + self.assertEqual(Rf'\\\\', R'\\\\') + self.assertEqual(fR'\\', R'\\') + self.assertEqual(fR'\\\\', R'\\\\') + self.assertEqual(FR'\\', R'\\') + self.assertEqual(FR'\\\\', R'\\\\') def test_fstring_format_spec_greedy_matching(self): self.assertEqual(f"{1:}}}", "1}") @@ -1251,8 +1167,8 @@ def test_yield(self): # Not terribly useful, but make sure the yield turns # a function into a generator def fn(y): - f"y:{yield y*2}" - f"{yield}" + f'y:{yield y*2}' + f'{yield}' g = fn(4) self.assertEqual(next(g), 8) @@ -1260,331 +1176,287 @@ def fn(y): def test_yield_send(self): def fn(x): - yield f"x:{yield (lambda i: x * i)}" + yield f'x:{yield (lambda i: x * i)}' g = fn(10) the_lambda = next(g) self.assertEqual(the_lambda(4), 40) - self.assertEqual(g.send("string"), "x:string") + self.assertEqual(g.send('string'), 'x:string') - # TODO: RUSTPYTHON SyntaxError - # def test_expressions_with_triple_quoted_strings(self): - # self.assertEqual(f"{'''x'''}", 'x') - # self.assertEqual(f"{'''eric's'''}", "eric's") + def test_expressions_with_triple_quoted_strings(self): + self.assertEqual(f"{'''x'''}", 'x') + self.assertEqual(f"{'''eric's'''}", "eric's") - # # Test concatenation within an expression - # self.assertEqual(f'{"x" """eric"s""" "y"}', 'xeric"sy') - # self.assertEqual(f'{"x" """eric"s"""}', 'xeric"s') - # self.assertEqual(f'{"""eric"s""" "y"}', 'eric"sy') - # self.assertEqual(f'{"""x""" """eric"s""" "y"}', 'xeric"sy') - # self.assertEqual(f'{"""x""" """eric"s""" """y"""}', 'xeric"sy') - # self.assertEqual(f'{r"""x""" """eric"s""" """y"""}', 'xeric"sy') + # Test concatenation within an expression + self.assertEqual(f'{"x" """eric"s""" "y"}', 'xeric"sy') + self.assertEqual(f'{"x" """eric"s"""}', 'xeric"s') + self.assertEqual(f'{"""eric"s""" "y"}', 'eric"sy') + self.assertEqual(f'{"""x""" """eric"s""" "y"}', 'xeric"sy') + self.assertEqual(f'{"""x""" """eric"s""" """y"""}', 'xeric"sy') + self.assertEqual(f'{r"""x""" """eric"s""" """y"""}', 'xeric"sy') def test_multiple_vars(self): x = 98 - y = "abc" - self.assertEqual(f"{x}{y}", "98abc") + y = 'abc' + self.assertEqual(f'{x}{y}', '98abc') - self.assertEqual(f"X{x}{y}", "X98abc") - self.assertEqual(f"{x}X{y}", "98Xabc") - self.assertEqual(f"{x}{y}X", "98abcX") + self.assertEqual(f'X{x}{y}', 'X98abc') + self.assertEqual(f'{x}X{y}', '98Xabc') + self.assertEqual(f'{x}{y}X', '98abcX') - self.assertEqual(f"X{x}Y{y}", "X98Yabc") - self.assertEqual(f"X{x}{y}Y", "X98abcY") - self.assertEqual(f"{x}X{y}Y", "98XabcY") + self.assertEqual(f'X{x}Y{y}', 'X98Yabc') + self.assertEqual(f'X{x}{y}Y', 'X98abcY') + self.assertEqual(f'{x}X{y}Y', '98XabcY') - self.assertEqual(f"X{x}Y{y}Z", "X98YabcZ") + self.assertEqual(f'X{x}Y{y}Z', 'X98YabcZ') def test_closure(self): def outer(x): def inner(): - return f"x:{x}" - + return f'x:{x}' return inner - self.assertEqual(outer("987")(), "x:987") - self.assertEqual(outer(7)(), "x:7") + self.assertEqual(outer('987')(), 'x:987') + self.assertEqual(outer(7)(), 'x:7') def test_arguments(self): y = 2 - def f(x, width): - return f"x={x*y:{width}}" + return f'x={x*y:{width}}' - self.assertEqual(f("foo", 10), "x=foofoo ") - x = "bar" - self.assertEqual(f(10, 10), "x= 20") + self.assertEqual(f('foo', 10), 'x=foofoo ') + x = 'bar' + self.assertEqual(f(10, 10), 'x= 20') def test_locals(self): value = 123 - self.assertEqual(f"v:{value}", "v:123") + self.assertEqual(f'v:{value}', 'v:123') def test_missing_variable(self): with self.assertRaises(NameError): - f"v:{value}" + f'v:{value}' def test_missing_format_spec(self): class O: def __format__(self, spec): if not spec: - return "*" + return '*' return spec - self.assertEqual(f"{O():x}", "x") - self.assertEqual(f"{O()}", "*") - self.assertEqual(f"{O():}", "*") + self.assertEqual(f'{O():x}', 'x') + self.assertEqual(f'{O()}', '*') + self.assertEqual(f'{O():}', '*') - self.assertEqual(f"{3:}", "3") - self.assertEqual(f"{3!s:}", "3") + self.assertEqual(f'{3:}', '3') + self.assertEqual(f'{3!s:}', '3') def test_global(self): - self.assertEqual(f"g:{a_global}", "g:global variable") - self.assertEqual(f"g:{a_global!r}", "g:'global variable'") + self.assertEqual(f'g:{a_global}', 'g:global variable') + self.assertEqual(f'g:{a_global!r}', "g:'global variable'") - a_local = "local variable" - self.assertEqual( - f"g:{a_global} l:{a_local}", "g:global variable l:local variable" - ) - self.assertEqual(f"g:{a_global!r}", "g:'global variable'") - self.assertEqual( - f"g:{a_global} l:{a_local!r}", "g:global variable l:'local variable'" - ) + a_local = 'local variable' + self.assertEqual(f'g:{a_global} l:{a_local}', + 'g:global variable l:local variable') + self.assertEqual(f'g:{a_global!r}', + "g:'global variable'") + self.assertEqual(f'g:{a_global} l:{a_local!r}', + "g:global variable l:'local variable'") - self.assertIn("module 'unittest' from", f"{unittest}") + self.assertIn("module 'unittest' from", f'{unittest}') def test_shadowed_global(self): - a_global = "really a local" - self.assertEqual(f"g:{a_global}", "g:really a local") - self.assertEqual(f"g:{a_global!r}", "g:'really a local'") - - a_local = "local variable" - self.assertEqual( - f"g:{a_global} l:{a_local}", "g:really a local l:local variable" - ) - self.assertEqual(f"g:{a_global!r}", "g:'really a local'") - self.assertEqual( - f"g:{a_global} l:{a_local!r}", "g:really a local l:'local variable'" - ) + a_global = 'really a local' + self.assertEqual(f'g:{a_global}', 'g:really a local') + self.assertEqual(f'g:{a_global!r}', "g:'really a local'") + + a_local = 'local variable' + self.assertEqual(f'g:{a_global} l:{a_local}', + 'g:really a local l:local variable') + self.assertEqual(f'g:{a_global!r}', + "g:'really a local'") + self.assertEqual(f'g:{a_global} l:{a_local!r}', + "g:really a local l:'local variable'") def test_call(self): def foo(x): - return "x=" + str(x) + return 'x=' + str(x) - self.assertEqual(f"{foo(10)}", "x=10") + self.assertEqual(f'{foo(10)}', 'x=10') def test_nested_fstrings(self): y = 5 - self.assertEqual(f'{f"{0}"*3}', "000") - self.assertEqual(f'{f"{y}"*3}', "555") + self.assertEqual(f'{f"{0}"*3}', '000') + self.assertEqual(f'{f"{y}"*3}', '555') def test_invalid_string_prefixes(self): - single_quote_cases = [ - "fu''", - "uf''", - "Fu''", - "fU''", - "Uf''", - "uF''", - "ufr''", - "urf''", - "fur''", - "fru''", - "rfu''", - "ruf''", - "FUR''", - "Fur''", - "fb''", - "fB''", - "Fb''", - "FB''", - "bf''", - "bF''", - "Bf''", - "BF''", - ] + single_quote_cases = ["fu''", + "uf''", + "Fu''", + "fU''", + "Uf''", + "uF''", + "ufr''", + "urf''", + "fur''", + "fru''", + "rfu''", + "ruf''", + "FUR''", + "Fur''", + "fb''", + "fB''", + "Fb''", + "FB''", + "bf''", + "bF''", + "Bf''", + "BF''",] double_quote_cases = [case.replace("'", '"') for case in single_quote_cases] - self.assertAllRaise( - SyntaxError, "invalid syntax", single_quote_cases + double_quote_cases - ) + self.assertAllRaise(SyntaxError, 'invalid syntax', + single_quote_cases + double_quote_cases) def test_leading_trailing_spaces(self): - self.assertEqual(f"{ 3}", "3") - self.assertEqual(f"{ 3}", "3") - self.assertEqual(f"{3 }", "3") - self.assertEqual(f"{3 }", "3") + self.assertEqual(f'{ 3}', '3') + self.assertEqual(f'{ 3}', '3') + self.assertEqual(f'{3 }', '3') + self.assertEqual(f'{3 }', '3') - self.assertEqual(f"expr={ {x: y for x, y in [(1, 2), ]}}", "expr={1: 2}") - self.assertEqual(f"expr={ {x: y for x, y in [(1, 2), ]} }", "expr={1: 2}") + self.assertEqual(f'expr={ {x: y for x, y in [(1, 2), ]}}', + 'expr={1: 2}') + self.assertEqual(f'expr={ {x: y for x, y in [(1, 2), ]} }', + 'expr={1: 2}') def test_not_equal(self): # There's a special test for this because there's a special # case in the f-string parser to look for != as not ending an # expression. Normally it would, while looking for !s or !r. - self.assertEqual(f"{3!=4}", "True") - self.assertEqual(f"{3!=4:}", "True") - self.assertEqual(f"{3!=4!s}", "True") - self.assertEqual(f"{3!=4!s:.3}", "Tru") + self.assertEqual(f'{3!=4}', 'True') + self.assertEqual(f'{3!=4:}', 'True') + self.assertEqual(f'{3!=4!s}', 'True') + self.assertEqual(f'{3!=4!s:.3}', 'Tru') def test_equal_equal(self): # Because an expression ending in = has special meaning, # there's a special test for ==. Make sure it works. - self.assertEqual(f"{0==1}", "False") + self.assertEqual(f'{0==1}', 'False') - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_conversions(self): - self.assertEqual(f"{3.14:10.10}", " 3.14") - self.assertEqual(f"{3.14!s:10.10}", "3.14 ") - self.assertEqual(f"{3.14!r:10.10}", "3.14 ") - self.assertEqual(f"{3.14!a:10.10}", "3.14 ") + self.assertEqual(f'{3.14:10.10}', ' 3.14') + self.assertEqual(f'{1.25!s:10.10}', '1.25 ') + self.assertEqual(f'{1.25!r:10.10}', '1.25 ') + self.assertEqual(f'{1.25!a:10.10}', '1.25 ') - self.assertEqual(f'{"a"}', "a") + self.assertEqual(f'{"a"}', 'a') self.assertEqual(f'{"a"!r}', "'a'") self.assertEqual(f'{"a"!a}', "'a'") # Conversions can have trailing whitespace after them since it # does not provide any significance - # TODO: RUSTPYTHON SyntaxError - # self.assertEqual(f"{3!s }", "3") - # self.assertEqual(f"{3.14!s :10.10}", "3.14 ") + self.assertEqual(f"{3!s }", "3") + self.assertEqual(f'{1.25!s :10.10}', '1.25 ') # Not a conversion. self.assertEqual(f'{"a!r"}', "a!r") # Not a conversion, but show that ! is allowed in a format spec. - self.assertEqual(f"{3.14:!<10.10}", "3.14!!!!!!") - - self.assertAllRaise( - SyntaxError, - "f-string: expecting '}'", - [ - "f'{3!'", - "f'{3!s'", - "f'{3!g'", - ], - ) - - self.assertAllRaise( - SyntaxError, - "f-string: missing conversion character", - [ - "f'{3!}'", - "f'{3!:'", - "f'{3!:}'", - ], - ) - - for conv_identifier in "g", "A", "G", "ä", "ɐ": - self.assertAllRaise( - SyntaxError, - "f-string: invalid conversion character %r: " - "expected 's', 'r', or 'a'" % conv_identifier, - ["f'{3!" + conv_identifier + "}'"], - ) - - for conv_non_identifier in "3", "!": - self.assertAllRaise( - SyntaxError, - "f-string: invalid conversion character", - ["f'{3!" + conv_non_identifier + "}'"], - ) - - for conv in " s", " s ": - self.assertAllRaise( - SyntaxError, - "f-string: conversion type must come right after the" - " exclamanation mark", - ["f'{3!" + conv + "}'"], - ) - - self.assertAllRaise( - SyntaxError, - "f-string: invalid conversion character 'ss': " "expected 's', 'r', or 'a'", - [ - "f'{3!ss}'", - "f'{3!ss:}'", - "f'{3!ss:s}'", - ], - ) + self.assertEqual(f'{3.14:!<10.10}', '3.14!!!!!!') + + self.assertAllRaise(SyntaxError, "f-string: expecting '}'", + ["f'{3!'", + "f'{3!s'", + "f'{3!g'", + ]) + + self.assertAllRaise(SyntaxError, 'f-string: missing conversion character', + ["f'{3!}'", + "f'{3!:'", + "f'{3!:}'", + ]) + + for conv_identifier in 'g', 'A', 'G', 'ä', 'ɐ': + self.assertAllRaise(SyntaxError, + "f-string: invalid conversion character %r: " + "expected 's', 'r', or 'a'" % conv_identifier, + ["f'{3!" + conv_identifier + "}'"]) + + for conv_non_identifier in '3', '!': + self.assertAllRaise(SyntaxError, + "f-string: invalid conversion character", + ["f'{3!" + conv_non_identifier + "}'"]) + + for conv in ' s', ' s ': + self.assertAllRaise(SyntaxError, + "f-string: conversion type must come right after the" + " exclamation mark", + ["f'{3!" + conv + "}'"]) + + self.assertAllRaise(SyntaxError, + "f-string: invalid conversion character 'ss': " + "expected 's', 'r', or 'a'", + ["f'{3!ss}'", + "f'{3!ss:}'", + "f'{3!ss:s}'", + ]) def test_assignment(self): - self.assertAllRaise( - SyntaxError, - r"invalid syntax", - [ - "f'' = 3", - "f'{0}' = x", - "f'{x}' = x", - ], - ) - - # TODO: RUSTPYTHON - @unittest.expectedFailure + self.assertAllRaise(SyntaxError, r'invalid syntax', + ["f'' = 3", + "f'{0}' = x", + "f'{x}' = x", + ]) + + @unittest.expectedFailure # TODO: RUSTPYTHON def test_del(self): - self.assertAllRaise( - SyntaxError, - "invalid syntax", - [ - "del f''", - "del '' f''", - ], - ) - - # TODO: RUSTPYTHON - @unittest.expectedFailure + self.assertAllRaise(SyntaxError, 'invalid syntax', + ["del f''", + "del '' f''", + ]) + + @unittest.expectedFailure # TODO: RUSTPYTHON def test_mismatched_braces(self): - self.assertAllRaise( - SyntaxError, - "f-string: single '}' is not allowed", - [ - "f'{{}'", - "f'{{}}}'", - "f'}'", - "f'x}'", - "f'x}x'", - r"f'\u007b}'", - # Can't have { or } in a format spec. - "f'{3:}>10}'", - "f'{3:}}>10}'", - ], - ) - - self.assertAllRaise( - SyntaxError, - "f-string: expecting '}'", - [ - "f'{3'", - "f'{3!'", - "f'{3:'", - "f'{3!s'", - "f'{3!s:'", - "f'{3!s:3'", - "f'x{'", - "f'x{x'", - "f'{x'", - "f'{3:s'", - "f'{{{'", - "f'{{}}{'", - "f'{'", - "f'{i='", # See gh-93418. - ], - ) - - self.assertAllRaise( - SyntaxError, - "f-string: expecting a valid expression after '{'", - [ - "f'{3:{{>10}'", - ], - ) + self.assertAllRaise(SyntaxError, "f-string: single '}' is not allowed", + ["f'{{}'", + "f'{{}}}'", + "f'}'", + "f'x}'", + "f'x}x'", + r"f'\u007b}'", + + # Can't have { or } in a format spec. + "f'{3:}>10}'", + "f'{3:}}>10}'", + ]) + + self.assertAllRaise(SyntaxError, "f-string: expecting '}'", + ["f'{3'", + "f'{3!'", + "f'{3:'", + "f'{3!s'", + "f'{3!s:'", + "f'{3!s:3'", + "f'x{'", + "f'x{x'", + "f'{x'", + "f'{3:s'", + "f'{{{'", + "f'{{}}{'", + "f'{'", + "f'{i='", # See gh-93418. + ]) + + self.assertAllRaise(SyntaxError, + "f-string: expecting a valid expression after '{'", + ["f'{3:{{>10}'", + ]) # But these are just normal strings. - self.assertEqual(f'{"{"}', "{") - self.assertEqual(f'{"}"}', "}") - self.assertEqual(f'{3:{"}"}>10}', "}}}}}}}}}3") - self.assertEqual(f'{2:{"{"}>10}', "{{{{{{{{{2") + self.assertEqual(f'{"{"}', '{') + self.assertEqual(f'{"}"}', '}') + self.assertEqual(f'{3:{"}"}>10}', '}}}}}}}}}3') + self.assertEqual(f'{2:{"{"}>10}', '{{{{{{{{{2') def test_if_conditional(self): # There's special logic in compile.c to test if the @@ -1593,7 +1465,7 @@ def test_if_conditional(self): def test_fstring(x, expected): flag = 0 - if f"{x}": + if f'{x}': flag = 1 else: flag = 2 @@ -1601,7 +1473,7 @@ def test_fstring(x, expected): def test_concat_empty(x, expected): flag = 0 - if "" f"{x}": + if '' f'{x}': flag = 1 else: flag = 2 @@ -1609,151 +1481,139 @@ def test_concat_empty(x, expected): def test_concat_non_empty(x, expected): flag = 0 - if " " f"{x}": + if ' ' f'{x}': flag = 1 else: flag = 2 self.assertEqual(flag, expected) - test_fstring("", 2) - test_fstring(" ", 1) + test_fstring('', 2) + test_fstring(' ', 1) - test_concat_empty("", 2) - test_concat_empty(" ", 1) + test_concat_empty('', 2) + test_concat_empty(' ', 1) - test_concat_non_empty("", 1) - test_concat_non_empty(" ", 1) + test_concat_non_empty('', 1) + test_concat_non_empty(' ', 1) def test_empty_format_specifier(self): - x = "test" - self.assertEqual(f"{x}", "test") - self.assertEqual(f"{x:}", "test") - self.assertEqual(f"{x!s:}", "test") - self.assertEqual(f"{x!r:}", "'test'") + x = 'test' + self.assertEqual(f'{x}', 'test') + self.assertEqual(f'{x:}', 'test') + self.assertEqual(f'{x!s:}', 'test') + self.assertEqual(f'{x!r:}', "'test'") def test_str_format_differences(self): - d = { - "a": "string", - 0: "integer", - } + d = {'a': 'string', + 0: 'integer', + } a = 0 - self.assertEqual(f"{d[0]}", "integer") - self.assertEqual(f'{d["a"]}', "string") - self.assertEqual(f"{d[a]}", "integer") - self.assertEqual("{d[a]}".format(d=d), "string") - self.assertEqual("{d[0]}".format(d=d), "integer") - - # TODO: RUSTPYTHON - @unittest.expectedFailure + self.assertEqual(f'{d[0]}', 'integer') + self.assertEqual(f'{d["a"]}', 'string') + self.assertEqual(f'{d[a]}', 'integer') + self.assertEqual('{d[a]}'.format(d=d), 'string') + self.assertEqual('{d[0]}'.format(d=d), 'integer') + + @unittest.expectedFailure # TODO: RUSTPYTHON def test_errors(self): # see issue 26287 - self.assertAllRaise( - TypeError, - "unsupported", - [ - r"f'{(lambda: 0):x}'", - r"f'{(0,):x}'", - ], - ) - self.assertAllRaise( - ValueError, - "Unknown format code", - [ - r"f'{1000:j}'", - r"f'{1000:j}'", - ], - ) + self.assertAllRaise(TypeError, 'unsupported', + [r"f'{(lambda: 0):x}'", + r"f'{(0,):x}'", + ]) + self.assertAllRaise(ValueError, 'Unknown format code', + [r"f'{1000:j}'", + r"f'{1000:j}'", + ]) def test_filename_in_syntaxerror(self): # see issue 38964 with temp_cwd() as cwd: - file_path = os.path.join(cwd, "t.py") - with open(file_path, "w", encoding="utf-8") as f: - f.write('f"{a b}"') # This generates a SyntaxError - _, _, stderr = assert_python_failure(file_path, PYTHONIOENCODING="ascii") - self.assertIn(file_path.encode("ascii", "backslashreplace"), stderr) + file_path = os.path.join(cwd, 't.py') + with open(file_path, 'w', encoding="utf-8") as f: + f.write('f"{a b}"') # This generates a SyntaxError + _, _, stderr = assert_python_failure(file_path, + PYTHONIOENCODING='ascii') + self.assertIn(file_path.encode('ascii', 'backslashreplace'), stderr) def test_loop(self): for i in range(1000): - self.assertEqual(f"i:{i}", "i:" + str(i)) + self.assertEqual(f'i:{i}', 'i:' + str(i)) def test_dict(self): - d = { - '"': "dquote", - "'": "squote", - "foo": "bar", - } - self.assertEqual(f"""{d["'"]}""", "squote") - self.assertEqual(f"""{d['"']}""", "dquote") + d = {'"': 'dquote', + "'": 'squote', + 'foo': 'bar', + } + self.assertEqual(f'''{d["'"]}''', 'squote') + self.assertEqual(f"""{d['"']}""", 'dquote') - self.assertEqual(f'{d["foo"]}', "bar") - self.assertEqual(f"{d['foo']}", "bar") + self.assertEqual(f'{d["foo"]}', 'bar') + self.assertEqual(f"{d['foo']}", 'bar') def test_backslash_char(self): # Check eval of a backslash followed by a control char. # See bpo-30682: this used to raise an assert in pydebug mode. - self.assertEqual(eval('f"\\\n"'), "") - self.assertEqual(eval('f"\\\r"'), "") + self.assertEqual(eval('f"\\\n"'), '') + self.assertEqual(eval('f"\\\r"'), '') + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: '1+2 = # my comment\n 3' != '1+2 = \n 3' def test_debug_conversion(self): - x = "A string" - self.assertEqual(f"{x=}", "x=" + repr(x)) - self.assertEqual(f"{x =}", "x =" + repr(x)) - self.assertEqual(f"{x=!s}", "x=" + str(x)) - self.assertEqual(f"{x=!r}", "x=" + repr(x)) - self.assertEqual(f"{x=!a}", "x=" + ascii(x)) + x = 'A string' + self.assertEqual(f'{x=}', 'x=' + repr(x)) + self.assertEqual(f'{x =}', 'x =' + repr(x)) + self.assertEqual(f'{x=!s}', 'x=' + str(x)) + self.assertEqual(f'{x=!r}', 'x=' + repr(x)) + self.assertEqual(f'{x=!a}', 'x=' + ascii(x)) x = 2.71828 - self.assertEqual(f"{x=:.2f}", "x=" + format(x, ".2f")) - self.assertEqual(f"{x=:}", "x=" + format(x, "")) - self.assertEqual(f"{x=!r:^20}", "x=" + format(repr(x), "^20")) - self.assertEqual(f"{x=!s:^20}", "x=" + format(str(x), "^20")) - self.assertEqual(f"{x=!a:^20}", "x=" + format(ascii(x), "^20")) + self.assertEqual(f'{x=:.2f}', 'x=' + format(x, '.2f')) + self.assertEqual(f'{x=:}', 'x=' + format(x, '')) + self.assertEqual(f'{x=!r:^20}', 'x=' + format(repr(x), '^20')) + self.assertEqual(f'{x=!s:^20}', 'x=' + format(str(x), '^20')) + self.assertEqual(f'{x=!a:^20}', 'x=' + format(ascii(x), '^20')) x = 9 - self.assertEqual(f"{3*x+15=}", "3*x+15=42") + self.assertEqual(f'{3*x+15=}', '3*x+15=42') # There is code in ast.c that deals with non-ascii expression values. So, # use a unicode identifier to trigger that. tenπ = 31.4 - self.assertEqual(f"{tenπ=:.2f}", "tenπ=31.40") + self.assertEqual(f'{tenπ=:.2f}', 'tenπ=31.40') # Also test with Unicode in non-identifiers. - self.assertEqual(f'{"Σ"=}', "\"Σ\"='Σ'") + self.assertEqual(f'{"Σ"=}', '"Σ"=\'Σ\'') # Make sure nested fstrings still work. - self.assertEqual(f'{f"{3.1415=:.1f}":*^20}', "*****3.1415=3.1*****") + self.assertEqual(f'{f"{3.1415=:.1f}":*^20}', '*****3.1415=3.1*****') # Make sure text before and after an expression with = works # correctly. - pi = "π" - self.assertEqual(f"alpha α {pi=} ω omega", "alpha α pi='π' ω omega") + pi = 'π' + self.assertEqual(f'alpha α {pi=} ω omega', "alpha α pi='π' ω omega") # Check multi-line expressions. - self.assertEqual( - f"""{ + self.assertEqual(f'''{ 3 -=}""", - "\n3\n=3", - ) +=}''', '\n3\n=3') # Since = is handled specially, make sure all existing uses of # it still work. - self.assertEqual(f"{0==1}", "False") - self.assertEqual(f"{0!=1}", "True") - self.assertEqual(f"{0<=1}", "True") - self.assertEqual(f"{0>=1}", "False") - self.assertEqual(f'{(x:="5")}', "5") - self.assertEqual(x, "5") - self.assertEqual(f"{(x:=5)}", "5") + self.assertEqual(f'{0==1}', 'False') + self.assertEqual(f'{0!=1}', 'True') + self.assertEqual(f'{0<=1}', 'True') + self.assertEqual(f'{0>=1}', 'False') + self.assertEqual(f'{(x:="5")}', '5') + self.assertEqual(x, '5') + self.assertEqual(f'{(x:=5)}', '5') self.assertEqual(x, 5) - self.assertEqual(f'{"="}', "=") + self.assertEqual(f'{"="}', '=') x = 20 # This isn't an assignment expression, it's 'x', with a format # spec of '=10'. See test_walrus: you need to use parens. - self.assertEqual(f"{x:=10}", " 20") + self.assertEqual(f'{x:=10}', ' 20') # Test named function parameters, to make sure '=' parsing works # there. @@ -1762,54 +1622,60 @@ def f(a): oldx = x x = a return oldx - x = 0 - self.assertEqual(f'{f(a="3=")}', "0") - self.assertEqual(x, "3=") - self.assertEqual(f"{f(a=4)}", "3=") + self.assertEqual(f'{f(a="3=")}', '0') + self.assertEqual(x, '3=') + self.assertEqual(f'{f(a=4)}', '3=') self.assertEqual(x, 4) # Check debug expressions in format spec y = 20 self.assertEqual(f"{2:{y=}}", "yyyyyyyyyyyyyyyyyyy2") - self.assertEqual( - f"{datetime.datetime.now():h1{y=}h2{y=}h3{y=}}", "h1y=20h2y=20h3y=20" - ) + self.assertEqual(f"{datetime.datetime.now():h1{y=}h2{y=}h3{y=}}", + 'h1y=20h2y=20h3y=20') # Make sure __format__ is being called. class C: def __format__(self, s): - return f"FORMAT-{s}" - + return f'FORMAT-{s}' def __repr__(self): - return "REPR" + return 'REPR' - self.assertEqual(f"{C()=}", "C()=REPR") - self.assertEqual(f"{C()=!r}", "C()=REPR") - self.assertEqual(f"{C()=:}", "C()=FORMAT-") - self.assertEqual(f"{C()=: }", "C()=FORMAT- ") - self.assertEqual(f"{C()=:x}", "C()=FORMAT-x") - self.assertEqual(f"{C()=!r:*^20}", "C()=********REPR********") - self.assertEqual(f"{C():{20=}}", "FORMAT-20=20") + self.assertEqual(f'{C()=}', 'C()=REPR') + self.assertEqual(f'{C()=!r}', 'C()=REPR') + self.assertEqual(f'{C()=:}', 'C()=FORMAT-') + self.assertEqual(f'{C()=: }', 'C()=FORMAT- ') + self.assertEqual(f'{C()=:x}', 'C()=FORMAT-x') + self.assertEqual(f'{C()=!r:*^20}', 'C()=********REPR********') + self.assertEqual(f"{C():{20=}}", 'FORMAT-20=20') self.assertRaises(SyntaxError, eval, "f'{C=]'") + # Make sure leading and following text works. - x = "foo" - self.assertEqual(f"X{x=}Y", "Xx=" + repr(x) + "Y") + x = 'foo' + self.assertEqual(f'X{x=}Y', 'Xx='+repr(x)+'Y') # Make sure whitespace around the = works. - self.assertEqual(f"X{x =}Y", "Xx =" + repr(x) + "Y") - self.assertEqual(f"X{x= }Y", "Xx= " + repr(x) + "Y") - self.assertEqual(f"X{x = }Y", "Xx = " + repr(x) + "Y") + self.assertEqual(f'X{x =}Y', 'Xx ='+repr(x)+'Y') + self.assertEqual(f'X{x= }Y', 'Xx= '+repr(x)+'Y') + self.assertEqual(f'X{x = }Y', 'Xx = '+repr(x)+'Y') self.assertEqual(f"sadsd {1 + 1 = :{1 + 1:1d}f}", "sadsd 1 + 1 = 2.000000") -# TODO: RUSTPYTHON SyntaxError -# self.assertEqual( -# f"{1+2 = # my comment -# }", -# "1+2 = \n 3", -# ) + self.assertEqual(f"{1+2 = # my comment + }", '1+2 = \n 3') + + self.assertEqual(f'{""" # booo + """=}', '""" # booo\n """=\' # booo\\n \'') + + self.assertEqual(f'{" # nooo "=}', '" # nooo "=\' # nooo \'') + self.assertEqual(f'{" \" # nooo \" "=}', '" \\" # nooo \\" "=\' " # nooo " \'') + + self.assertEqual(f'{ # some comment goes here + """hello"""=}', ' \n """hello"""=\'hello\'') + self.assertEqual(f'{"""# this is not a comment + a""" # this is a comment + }', '# this is not a comment\n a') # These next lines contains tabs. Backslash escapes don't # work in f-strings. @@ -1817,65 +1683,68 @@ def __repr__(self): # this will be to dynamically created and exec the f-strings. But # that's such a hassle I'll save it for another day. For now, convert # the tabs to spaces just to shut up patchcheck. - # self.assertEqual(f'X{x =}Y', 'Xx\t='+repr(x)+'Y') - # self.assertEqual(f'X{x = }Y', 'Xx\t=\t'+repr(x)+'Y') + #self.assertEqual(f'X{x =}Y', 'Xx\t='+repr(x)+'Y') + #self.assertEqual(f'X{x = }Y', 'Xx\t=\t'+repr(x)+'Y') + + def test_debug_expressions_are_raw_strings(self): + with warnings.catch_warnings(): + warnings.simplefilter('ignore', SyntaxWarning) + self.assertEqual(eval("""f'{b"\\N{OX}"=}'"""), 'b"\\N{OX}"=b\'\\\\N{OX}\'') + self.assertEqual(f'{r"\xff"=}', 'r"\\xff"=\'\\\\xff\'') + self.assertEqual(f'{r"\n"=}', 'r"\\n"=\'\\\\n\'') + self.assertEqual(f"{'\''=}", "'\\''=\"'\"") + self.assertEqual(f'{'\xc5'=}', r"'\xc5'='Å'") def test_walrus(self): x = 20 # This isn't an assignment expression, it's 'x', with a format # spec of '=10'. - self.assertEqual(f"{x:=10}", " 20") + self.assertEqual(f'{x:=10}', ' 20') # This is an assignment expression, which requires parens. - self.assertEqual(f"{(x:=10)}", "10") + self.assertEqual(f'{(x:=10)}', '10') self.assertEqual(x, 10) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_invalid_syntax_error_message(self): - with self.assertRaisesRegex( - SyntaxError, "f-string: expecting '=', or '!', or ':', or '}'" - ): + with self.assertRaisesRegex(SyntaxError, + "f-string: expecting '=', or '!', or ':', or '}'"): compile("f'{a $ b}'", "?", "exec") def test_with_two_commas_in_format_specifier(self): error_msg = re.escape("Cannot specify ',' with ','.") with self.assertRaisesRegex(ValueError, error_msg): - f"{1:,,}" + f'{1:,,}' def test_with_two_underscore_in_format_specifier(self): error_msg = re.escape("Cannot specify '_' with '_'.") with self.assertRaisesRegex(ValueError, error_msg): - f"{1:__}" + f'{1:__}' def test_with_a_commas_and_an_underscore_in_format_specifier(self): error_msg = re.escape("Cannot specify both ',' and '_'.") with self.assertRaisesRegex(ValueError, error_msg): - f"{1:,_}" + f'{1:,_}' def test_with_an_underscore_and_a_comma_in_format_specifier(self): error_msg = re.escape("Cannot specify both ',' and '_'.") with self.assertRaisesRegex(ValueError, error_msg): - f"{1:_,}" + f'{1:_,}' - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_syntax_error_for_starred_expressions(self): with self.assertRaisesRegex(SyntaxError, "can't use starred expression here"): compile("f'{*a}'", "?", "exec") - with self.assertRaisesRegex( - SyntaxError, "f-string: expecting a valid expression after '{'" - ): + with self.assertRaisesRegex(SyntaxError, + "f-string: expecting a valid expression after '{'"): compile("f'{**a}'", "?", "exec") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_not_closing_quotes(self): self.assertAllRaise(SyntaxError, "unterminated f-string literal", ['f"', "f'"]) - self.assertAllRaise( - SyntaxError, "unterminated triple-quoted f-string literal", ['f"""', "f'''"] - ) + self.assertAllRaise(SyntaxError, "unterminated triple-quoted f-string literal", + ['f"""', "f'''"]) # Ensure that the errors are reported at the correct line number. data = '''\ x = 1 + 1 @@ -1891,56 +1760,126 @@ def test_not_closing_quotes(self): except SyntaxError as e: self.assertEqual(e.text, 'z = f"""') self.assertEqual(e.lineno, 3) - - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_syntax_error_after_debug(self): - self.assertAllRaise( - SyntaxError, - "f-string: expecting a valid expression after '{'", - [ - "f'{1=}{;'", - "f'{1=}{+;'", - "f'{1=}{2}{;'", - "f'{1=}{3}{;'", - ], - ) - self.assertAllRaise( - SyntaxError, - "f-string: expecting '=', or '!', or ':', or '}'", - [ - "f'{1=}{1;'", - "f'{1=}{1;}'", - ], - ) + self.assertAllRaise(SyntaxError, "f-string: expecting a valid expression after '{'", + [ + "f'{1=}{;'", + "f'{1=}{+;'", + "f'{1=}{2}{;'", + "f'{1=}{3}{;'", + ]) + self.assertAllRaise(SyntaxError, "f-string: expecting '=', or '!', or ':', or '}'", + [ + "f'{1=}{1;'", + "f'{1=}{1;}'", + ]) def test_debug_in_file(self): with temp_cwd(): - script = "script.py" - with open("script.py", "w") as f: + script = 'script.py' + with open('script.py', 'w') as f: f.write(f"""\ print(f'''{{ 3 =}}''')""") _, stdout, _ = assert_python_ok(script) - self.assertEqual( - stdout.decode("utf-8").strip().replace("\r\n", "\n").replace("\r", "\n"), - "3\n=3", - ) + self.assertEqual(stdout.decode('utf-8').strip().replace('\r\n', '\n').replace('\r', '\n'), + "3\n=3") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_syntax_warning_infinite_recursion_in_file(self): with temp_cwd(): - script = "script.py" - with open(script, "w") as f: + script = 'script.py' + with open(script, 'w') as f: f.write(r"print(f'\{1}')") _, stdout, stderr = assert_python_ok(script) - self.assertIn(rb"\1", stdout) + self.assertIn(rb'\1', stdout) self.assertEqual(len(stderr.strip().splitlines()), 2) + @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: module 'dis' has no attribute 'get_instructions' + def test_fstring_without_formatting_bytecode(self): + # f-string without any formatting should emit the same bytecode + # as a normal string. See gh-99606. + def get_code(s): + return [(i.opname, i.oparg) for i in dis.get_instructions(s)] + + for s in ["", "some string"]: + self.assertEqual(get_code(f"'{s}'"), get_code(f"f'{s}'")) + + def test_gh129093(self): + self.assertEqual(f'{1==2=}', '1==2=False') + self.assertEqual(f'{1 == 2=}', '1 == 2=False') + self.assertEqual(f'{1!=2=}', '1!=2=True') + self.assertEqual(f'{1 != 2=}', '1 != 2=True') + + self.assertEqual(f'{(1) != 2=}', '(1) != 2=True') + self.assertEqual(f'{(1*2) != (3)=}', '(1*2) != (3)=True') + + self.assertEqual(f'{1 != 2 == 3 != 4=}', '1 != 2 == 3 != 4=False') + self.assertEqual(f'{1 == 2 != 3 == 4=}', '1 == 2 != 3 == 4=False') + + self.assertEqual(f'{f'{1==2=}'=}', "f'{1==2=}'='1==2=False'") + self.assertEqual(f'{f'{1 == 2=}'=}', "f'{1 == 2=}'='1 == 2=False'") + self.assertEqual(f'{f'{1!=2=}'=}', "f'{1!=2=}'='1!=2=True'") + self.assertEqual(f'{f'{1 != 2=}'=}', "f'{1 != 2=}'='1 != 2=True'") + + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: "f-string: newlines are not allowed in format specifiers" does not match "'unexpected EOF while parsing' (, line 2)" + def test_newlines_in_format_specifiers(self): + cases = [ + """f'{1:d\n}'""", + """f'__{ + 1:d + }__'""", + '''f"{value:. + {'2f'}}"''', + '''f"{value: + {'.2f'}f}"''', + '''f"{value: + #{'x'}}"''', + ] + self.assertAllRaise(SyntaxError, "f-string: newlines are not allowed in format specifiers", cases) + + valid_cases = [ + """f'''__{ + 1:d + }__'''""", + """f'''{1:d\n}'''""", + ] + + for case in valid_cases: + compile(case, "", "exec") + + def test_raw_fstring_format_spec(self): + # Test raw f-string format spec behavior (Issue #137314). + # + # Raw f-strings should preserve literal backslashes in format specifications, + # not interpret them as escape sequences. + class UnchangedFormat: + """Test helper that returns the format spec unchanged.""" + def __format__(self, format): + return format + + # Test basic escape sequences + self.assertEqual(f"{UnchangedFormat():\xFF}", 'ÿ') + self.assertEqual(rf"{UnchangedFormat():\xFF}", '\\xFF') + + # Test nested expressions with raw/non-raw combinations + self.assertEqual(rf"{UnchangedFormat():{'\xFF'}}", 'ÿ') + self.assertEqual(f"{UnchangedFormat():{r'\xFF'}}", '\\xFF') + self.assertEqual(rf"{UnchangedFormat():{r'\xFF'}}", '\\xFF') + + # Test continuation character in format specs + self.assertEqual(f"""{UnchangedFormat():{'a'\ + 'b'}}""", 'ab') + self.assertEqual(rf"""{UnchangedFormat():{'a'\ + 'b'}}""", 'ab') + + # Test multiple format specs in same raw f-string + self.assertEqual(rf"{UnchangedFormat():\xFF} {UnchangedFormat():\n}", '\\xFF \\n') + -if __name__ == "__main__": +if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_glob.py b/Lib/test/test_glob.py index e11ea81a7d..c3fb8939a6 100644 --- a/Lib/test/test_glob.py +++ b/Lib/test/test_glob.py @@ -1,9 +1,12 @@ import glob import os +import re import shutil import sys import unittest +import warnings +from test.support import is_wasi, Py_DEBUG from test.support.os_helper import (TESTFN, skip_unless_symlink, can_symlink, create_empty_file, change_cwd) @@ -167,37 +170,45 @@ def test_glob_directory_names(self): self.norm('aab', 'F')]) def test_glob_directory_with_trailing_slash(self): - # Patterns ending with a slash shouldn't match non-dirs - res = glob.glob(self.norm('Z*Z') + os.sep) - self.assertEqual(res, []) - res = glob.glob(self.norm('ZZZ') + os.sep) - self.assertEqual(res, []) - # When there is a wildcard pattern which ends with os.sep, glob() - # doesn't blow up. - res = glob.glob(self.norm('aa*') + os.sep) - self.assertEqual(len(res), 2) - # either of these results is reasonable - self.assertIn(set(res), [ - {self.norm('aaa'), self.norm('aab')}, - {self.norm('aaa') + os.sep, self.norm('aab') + os.sep}, - ]) + seps = (os.sep, os.altsep) if os.altsep else (os.sep,) + for sep in seps: + # Patterns ending with a slash shouldn't match non-dirs + self.assertEqual(glob.glob(self.norm('Z*Z') + sep), []) + self.assertEqual(glob.glob(self.norm('ZZZ') + sep), []) + self.assertEqual(glob.glob(self.norm('aaa') + sep), + [self.norm('aaa') + sep]) + # Preserving the redundant separators is an implementation detail. + self.assertEqual(glob.glob(self.norm('aaa') + sep*2), + [self.norm('aaa') + sep*2]) + # When there is a wildcard pattern which ends with a pathname + # separator, glob() doesn't blow. + # The result should end with the pathname separator. + # Normalizing the trailing separator is an implementation detail. + eq = self.assertSequencesEqual_noorder + eq(glob.glob(self.norm('aa*') + sep), + [self.norm('aaa') + os.sep, self.norm('aab') + os.sep]) + # Stripping the redundant separators is an implementation detail. + eq(glob.glob(self.norm('aa*') + sep*2), + [self.norm('aaa') + os.sep, self.norm('aab') + os.sep]) def test_glob_bytes_directory_with_trailing_slash(self): # Same as test_glob_directory_with_trailing_slash, but with a # bytes argument. - res = glob.glob(os.fsencode(self.norm('Z*Z') + os.sep)) - self.assertEqual(res, []) - res = glob.glob(os.fsencode(self.norm('ZZZ') + os.sep)) - self.assertEqual(res, []) - res = glob.glob(os.fsencode(self.norm('aa*') + os.sep)) - self.assertEqual(len(res), 2) - # either of these results is reasonable - self.assertIn(set(res), [ - {os.fsencode(self.norm('aaa')), - os.fsencode(self.norm('aab'))}, - {os.fsencode(self.norm('aaa') + os.sep), - os.fsencode(self.norm('aab') + os.sep)}, - ]) + seps = (os.sep, os.altsep) if os.altsep else (os.sep,) + for sep in seps: + self.assertEqual(glob.glob(os.fsencode(self.norm('Z*Z') + sep)), []) + self.assertEqual(glob.glob(os.fsencode(self.norm('ZZZ') + sep)), []) + self.assertEqual(glob.glob(os.fsencode(self.norm('aaa') + sep)), + [os.fsencode(self.norm('aaa') + sep)]) + self.assertEqual(glob.glob(os.fsencode(self.norm('aaa') + sep*2)), + [os.fsencode(self.norm('aaa') + sep*2)]) + eq = self.assertSequencesEqual_noorder + eq(glob.glob(os.fsencode(self.norm('aa*') + sep)), + [os.fsencode(self.norm('aaa') + os.sep), + os.fsencode(self.norm('aab') + os.sep)]) + eq(glob.glob(os.fsencode(self.norm('aa*') + sep*2)), + [os.fsencode(self.norm('aaa') + os.sep), + os.fsencode(self.norm('aab') + os.sep)]) @skip_unless_symlink def test_glob_symlinks(self): @@ -205,8 +216,7 @@ def test_glob_symlinks(self): eq(self.glob('sym3'), [self.norm('sym3')]) eq(self.glob('sym3', '*'), [self.norm('sym3', 'EF'), self.norm('sym3', 'efg')]) - self.assertIn(self.glob('sym3' + os.sep), - [[self.norm('sym3')], [self.norm('sym3') + os.sep]]) + eq(self.glob('sym3' + os.sep), [self.norm('sym3') + os.sep]) eq(self.glob('*', '*F'), [self.norm('aaa', 'zzzF'), self.norm('aab', 'F'), self.norm('sym3', 'EF')]) @@ -364,6 +374,8 @@ def test_glob_named_pipe(self): self.assertEqual(self.rglob('mypipe', 'sub'), []) self.assertEqual(self.rglob('mypipe', '*'), []) + + @unittest.skipIf(is_wasi and Py_DEBUG, "requires too much stack") def test_glob_many_open_files(self): depth = 30 base = os.path.join(self.tempdir, 'deep') @@ -381,10 +393,134 @@ def test_glob_many_open_files(self): for it in iters: self.assertEqual(next(it), p) + def test_glob0(self): + with self.assertWarns(DeprecationWarning): + glob.glob0(self.tempdir, 'a') + + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + eq = self.assertSequencesEqual_noorder + eq(glob.glob0(self.tempdir, 'a'), ['a']) + eq(glob.glob0(self.tempdir, '.bb'), ['.bb']) + eq(glob.glob0(self.tempdir, '.b*'), []) + eq(glob.glob0(self.tempdir, 'b'), []) + eq(glob.glob0(self.tempdir, '?'), []) + eq(glob.glob0(self.tempdir, '*a'), []) + eq(glob.glob0(self.tempdir, 'a*'), []) + + def test_glob1(self): + with self.assertWarns(DeprecationWarning): + glob.glob1(self.tempdir, 'a') + + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + eq = self.assertSequencesEqual_noorder + eq(glob.glob1(self.tempdir, 'a'), ['a']) + eq(glob.glob1(self.tempdir, '.bb'), ['.bb']) + eq(glob.glob1(self.tempdir, '.b*'), ['.bb']) + eq(glob.glob1(self.tempdir, 'b'), []) + eq(glob.glob1(self.tempdir, '?'), ['a']) + eq(glob.glob1(self.tempdir, '*a'), ['a', 'aaa']) + eq(glob.glob1(self.tempdir, 'a*'), ['a', 'aaa', 'aab']) + + def test_translate_matching(self): + match = re.compile(glob.translate('*')).match + self.assertIsNotNone(match('foo')) + self.assertIsNotNone(match('foo.bar')) + self.assertIsNone(match('.foo')) + match = re.compile(glob.translate('.*')).match + self.assertIsNotNone(match('.foo')) + match = re.compile(glob.translate('**', recursive=True)).match + self.assertIsNotNone(match('foo')) + self.assertIsNone(match('.foo')) + self.assertIsNotNone(match(os.path.join('foo', 'bar'))) + self.assertIsNone(match(os.path.join('foo', '.bar'))) + self.assertIsNone(match(os.path.join('.foo', 'bar'))) + self.assertIsNone(match(os.path.join('.foo', '.bar'))) + match = re.compile(glob.translate('**/*', recursive=True)).match + self.assertIsNotNone(match(os.path.join('foo', 'bar'))) + self.assertIsNone(match(os.path.join('foo', '.bar'))) + self.assertIsNone(match(os.path.join('.foo', 'bar'))) + self.assertIsNone(match(os.path.join('.foo', '.bar'))) + match = re.compile(glob.translate('*/**', recursive=True)).match + self.assertIsNotNone(match(os.path.join('foo', 'bar'))) + self.assertIsNone(match(os.path.join('foo', '.bar'))) + self.assertIsNone(match(os.path.join('.foo', 'bar'))) + self.assertIsNone(match(os.path.join('.foo', '.bar'))) + match = re.compile(glob.translate('**/.bar', recursive=True)).match + self.assertIsNotNone(match(os.path.join('foo', '.bar'))) + self.assertIsNone(match(os.path.join('.foo', '.bar'))) + match = re.compile(glob.translate('**/*.*', recursive=True)).match + self.assertIsNone(match(os.path.join('foo', 'bar'))) + self.assertIsNone(match(os.path.join('foo', '.bar'))) + self.assertIsNotNone(match(os.path.join('foo', 'bar.txt'))) + self.assertIsNone(match(os.path.join('foo', '.bar.txt'))) + + def test_translate(self): + def fn(pat): + return glob.translate(pat, seps='/') + self.assertEqual(fn('foo'), r'(?s:foo)\Z') + self.assertEqual(fn('foo/bar'), r'(?s:foo/bar)\Z') + self.assertEqual(fn('*'), r'(?s:[^/.][^/]*)\Z') + self.assertEqual(fn('?'), r'(?s:(?!\.)[^/])\Z') + self.assertEqual(fn('a*'), r'(?s:a[^/]*)\Z') + self.assertEqual(fn('*a'), r'(?s:(?!\.)[^/]*a)\Z') + self.assertEqual(fn('.*'), r'(?s:\.[^/]*)\Z') + self.assertEqual(fn('?aa'), r'(?s:(?!\.)[^/]aa)\Z') + self.assertEqual(fn('aa?'), r'(?s:aa[^/])\Z') + self.assertEqual(fn('aa[ab]'), r'(?s:aa[ab])\Z') + self.assertEqual(fn('**'), r'(?s:(?!\.)[^/]*)\Z') + self.assertEqual(fn('***'), r'(?s:(?!\.)[^/]*)\Z') + self.assertEqual(fn('a**'), r'(?s:a[^/]*)\Z') + self.assertEqual(fn('**b'), r'(?s:(?!\.)[^/]*b)\Z') + self.assertEqual(fn('/**/*/*.*/**'), + r'(?s:/(?!\.)[^/]*/[^/.][^/]*/(?!\.)[^/]*\.[^/]*/(?!\.)[^/]*)\Z') + + def test_translate_include_hidden(self): + def fn(pat): + return glob.translate(pat, include_hidden=True, seps='/') + self.assertEqual(fn('foo'), r'(?s:foo)\Z') + self.assertEqual(fn('foo/bar'), r'(?s:foo/bar)\Z') + self.assertEqual(fn('*'), r'(?s:[^/]+)\Z') + self.assertEqual(fn('?'), r'(?s:[^/])\Z') + self.assertEqual(fn('a*'), r'(?s:a[^/]*)\Z') + self.assertEqual(fn('*a'), r'(?s:[^/]*a)\Z') + self.assertEqual(fn('.*'), r'(?s:\.[^/]*)\Z') + self.assertEqual(fn('?aa'), r'(?s:[^/]aa)\Z') + self.assertEqual(fn('aa?'), r'(?s:aa[^/])\Z') + self.assertEqual(fn('aa[ab]'), r'(?s:aa[ab])\Z') + self.assertEqual(fn('**'), r'(?s:[^/]*)\Z') + self.assertEqual(fn('***'), r'(?s:[^/]*)\Z') + self.assertEqual(fn('a**'), r'(?s:a[^/]*)\Z') + self.assertEqual(fn('**b'), r'(?s:[^/]*b)\Z') + self.assertEqual(fn('/**/*/*.*/**'), r'(?s:/[^/]*/[^/]+/[^/]*\.[^/]*/[^/]*)\Z') + + def test_translate_recursive(self): + def fn(pat): + return glob.translate(pat, recursive=True, include_hidden=True, seps='/') + self.assertEqual(fn('*'), r'(?s:[^/]+)\Z') + self.assertEqual(fn('?'), r'(?s:[^/])\Z') + self.assertEqual(fn('**'), r'(?s:.*)\Z') + self.assertEqual(fn('**/**'), r'(?s:.*)\Z') + self.assertEqual(fn('***'), r'(?s:[^/]*)\Z') + self.assertEqual(fn('a**'), r'(?s:a[^/]*)\Z') + self.assertEqual(fn('**b'), r'(?s:[^/]*b)\Z') + self.assertEqual(fn('/**/*/*.*/**'), r'(?s:/(?:.+/)?[^/]+/[^/]*\.[^/]*/.*)\Z') + + def test_translate_seps(self): + def fn(pat): + return glob.translate(pat, recursive=True, include_hidden=True, seps=['/', '\\']) + self.assertEqual(fn('foo/bar\\baz'), r'(?s:foo[/\\]bar[/\\]baz)\Z') + self.assertEqual(fn('**/*'), r'(?s:(?:.+[/\\])?[^/\\]+)\Z') + @skip_unless_symlink class SymlinkLoopGlobTests(unittest.TestCase): + # gh-109959: On Linux, glob._isdir() and glob._lexists() can return False + # randomly when checking the "link/" symbolic link. + # https://github.com/python/cpython/issues/109959#issuecomment-2577550700 + @unittest.skip("flaky test") def test_selflink(self): tempdir = TESTFN + "_dir" os.makedirs(tempdir) diff --git a/Lib/test/test_io.py b/Lib/test/test_io.py index e91d454b72..272098782d 100644 --- a/Lib/test/test_io.py +++ b/Lib/test/test_io.py @@ -39,11 +39,9 @@ from test import support from test.support.script_helper import ( assert_python_ok, assert_python_failure, run_python_until_end) -from test.support import import_helper -from test.support import os_helper -from test.support import threading_helper -from test.support import warnings_helper -from test.support import skip_if_sanitizer +from test.support import ( + import_helper, is_apple, os_helper, threading_helper, warnings_helper, +) from test.support.os_helper import FakePath import codecs @@ -66,10 +64,6 @@ def byteslike(*pos, **kw): class EmptyStruct(ctypes.Structure): pass -# Does io.IOBase finalizer log the exception if the close() method fails? -# The exception is ignored silently by default in release build. -IOBASE_EMITS_UNRAISABLE = (support.Py_DEBUG or sys.flags.dev_mode) - def _default_chunk_size(): """Get the default TextIOWrapper chunk size""" @@ -631,10 +625,10 @@ def test_raw_bytes_io(self): self.read_ops(f, True) def test_large_file_ops(self): - # On Windows and Mac OSX this test consumes large resources; It takes - # a long time to build the >2 GiB file and takes >2 GiB of disk space - # therefore the resource must be enabled to run this test. - if sys.platform[:3] == 'win' or sys.platform == 'darwin': + # On Windows and Apple platforms this test consumes large resources; It + # takes a long time to build the >2 GiB file and takes >2 GiB of disk + # space therefore the resource must be enabled to run this test. + if sys.platform[:3] == 'win' or is_apple: support.requires( 'largefile', 'test requires %s bytes and a long time to run' % self.LARGE) @@ -645,11 +639,9 @@ def test_large_file_ops(self): def test_with_open(self): for bufsize in (0, 100): - f = None with self.open(os_helper.TESTFN, "wb", bufsize) as f: f.write(b"xxx") self.assertEqual(f.closed, True) - f = None try: with self.open(os_helper.TESTFN, "wb", bufsize) as f: 1/0 @@ -788,8 +780,7 @@ def test_closefd_attr(self): file = self.open(f.fileno(), "r", encoding="utf-8", closefd=False) self.assertEqual(file.buffer.raw.closefd, False) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_garbage_collection(self): # FileIO objects are collected, and collecting them flushes # all data to disk. @@ -904,7 +895,7 @@ def test_bad_opener_negative_1(self): def badopener(fname, flags): return -1 with self.assertRaises(ValueError) as cm: - open('non-existent', 'r', opener=badopener) + self.open('non-existent', 'r', opener=badopener) self.assertEqual(str(cm.exception), 'opener returned -1') def test_bad_opener_other_negative(self): @@ -912,7 +903,7 @@ def test_bad_opener_other_negative(self): def badopener(fname, flags): return -2 with self.assertRaises(ValueError) as cm: - open('non-existent', 'r', opener=badopener) + self.open('non-existent', 'r', opener=badopener) self.assertEqual(str(cm.exception), 'opener returned -2') def test_opener_invalid_fd(self): @@ -1048,11 +1039,41 @@ def flush(self): # Silence destructor error R.flush = lambda self: None + @threading_helper.requires_working_threading() + def test_write_readline_races(self): + # gh-134908: Concurrent iteration over a file caused races + thread_count = 2 + write_count = 100 + read_count = 100 + + def writer(file, barrier): + barrier.wait() + for _ in range(write_count): + file.write("x") + + def reader(file, barrier): + barrier.wait() + for _ in range(read_count): + for line in file: + self.assertEqual(line, "") + + with self.open(os_helper.TESTFN, "w+") as f: + barrier = threading.Barrier(thread_count + 1) + reader = threading.Thread(target=reader, args=(f, barrier)) + writers = [threading.Thread(target=writer, args=(f, barrier)) + for _ in range(thread_count)] + with threading_helper.catch_threading_exception() as cm: + with threading_helper.start_threads(writers + [reader]): + pass + self.assertIsNone(cm.exc_type) + + self.assertEqual(os.stat(os_helper.TESTFN).st_size, + write_count * thread_count) + class CIOTest(IOTest): - # TODO: RUSTPYTHON, cyclic gc - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON; cyclic gc def test_IOBase_finalize(self): # Issue #12149: segmentation fault on _PyIOBase_finalize when both a # class which inherits IOBase and an object of this class are caught @@ -1071,10 +1092,9 @@ def close(self): support.gc_collect() self.assertIsNone(wr(), wr) - # TODO: RUSTPYTHON, AssertionError: filter ('', ResourceWarning) did not catch any warning - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: filter ('', ResourceWarning) did not catch any warning def test_destructor(self): - super().test_destructor(self) + return super().test_destructor() @support.cpython_only class TestIOCTypes(unittest.TestCase): @@ -1165,9 +1185,32 @@ def test_disallow_instantiation(self): _io = self._io support.check_disallow_instantiation(self, _io._BytesIOBuffer) + def test_stringio_setstate(self): + # gh-127182: Calling __setstate__() with invalid arguments must not crash + obj = self._io.StringIO() + with self.assertRaisesRegex( + TypeError, + 'initial_value must be str or None, not int', + ): + obj.__setstate__((1, '', 0, {})) + + obj.__setstate__((None, '', 0, {})) # should not crash + self.assertEqual(obj.getvalue(), '') + + obj.__setstate__(('', '', 0, {})) + self.assertEqual(obj.getvalue(), '') + class PyIOTest(IOTest): pass + @unittest.expectedFailure # TODO: RUSTPYTHON; OSError: Negative file descriptor + def test_bad_opener_negative_1(): + return super().test_bad_opener_negative_1() + + @unittest.expectedFailure # TODO: RUSTPYTHON; OSError: Negative file descriptor + def test_bad_opener_other_negative(): + return super().test_bad_opener_other_negative() + @support.cpython_only class APIMismatchTest(unittest.TestCase): @@ -1175,7 +1218,7 @@ class APIMismatchTest(unittest.TestCase): def test_RawIOBase_io_in_pyio_match(self): """Test that pyio RawIOBase class has all c RawIOBase methods""" mismatch = support.detect_api_mismatch(pyio.RawIOBase, io.RawIOBase, - ignore=('__weakref__',)) + ignore=('__weakref__', '__static_attributes__')) self.assertEqual(mismatch, set(), msg='Python RawIOBase does not have all C RawIOBase methods') def test_RawIOBase_pyio_in_io_match(self): @@ -1244,6 +1287,7 @@ def _with(): # a ValueError. self.assertRaises(ValueError, _with) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_error_through_destructor(self): # Test that the exception state is not modified by a destructor, # even if close() fails. @@ -1252,10 +1296,7 @@ def test_error_through_destructor(self): with self.assertRaises(AttributeError): self.tp(rawio).xyzzy - if not IOBASE_EMITS_UNRAISABLE: - self.assertIsNone(cm.unraisable) - elif cm.unraisable is not None: - self.assertEqual(cm.unraisable.exc_type, OSError) + self.assertEqual(cm.unraisable.exc_type, OSError) def test_repr(self): raw = self.MockRawIO() @@ -1271,11 +1312,9 @@ def test_recursive_repr(self): # Issue #25455 raw = self.MockRawIO() b = self.tp(raw) - with support.swap_attr(raw, 'name', b): - try: + with support.swap_attr(raw, 'name', b), support.infinite_recursion(25): + with self.assertRaises(RuntimeError): repr(b) # Should not crash - except RuntimeError: - pass def test_flush_error_on_close(self): # Test that buffered file is closed despite failed flush @@ -1356,6 +1395,28 @@ def test_readonly_attributes(self): with self.assertRaises(AttributeError): buf.raw = x + def test_pickling_subclass(self): + global MyBufferedIO + class MyBufferedIO(self.tp): + def __init__(self, raw, tag): + super().__init__(raw) + self.tag = tag + def __getstate__(self): + return self.tag, self.raw.getvalue() + def __setstate__(slf, state): + tag, value = state + slf.__init__(self.BytesIO(value), tag) + + raw = self.BytesIO(b'data') + buf = MyBufferedIO(raw, tag='ham') + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(protocol=proto): + pickled = pickle.dumps(buf, proto) + newbuf = pickle.loads(pickled) + self.assertEqual(newbuf.raw.getvalue(), b'data') + self.assertEqual(newbuf.tag, 'ham') + del MyBufferedIO + class SizeofTest: @@ -1717,20 +1778,6 @@ def test_seek_character_device_file(self): class CBufferedReaderTest(BufferedReaderTest, SizeofTest): tp = io.BufferedReader - @unittest.skip("TODO: RUSTPYTHON, fallible allocation") - @skip_if_sanitizer(memory=True, address=True, thread=True, - reason="sanitizer defaults to crashing " - "instead of returning NULL for malloc failure.") - def test_constructor(self): - BufferedReaderTest.test_constructor(self) - # The allocation can succeed on 32-bit builds, e.g. with more - # than 2 GiB RAM and a 64-bit kernel. - if sys.maxsize > 0x7FFFFFFF: - rawio = self.MockRawIO() - bufio = self.tp(rawio) - self.assertRaises((OverflowError, MemoryError, ValueError), - bufio.__init__, rawio, sys.maxsize) - def test_initialization(self): rawio = self.MockRawIO([b"abc"]) bufio = self.tp(rawio) @@ -1748,8 +1795,7 @@ def test_misbehaved_io_read(self): # checking this is not so easy. self.assertRaises(OSError, bufio.read, 10) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_garbage_collection(self): # C BufferedReader objects are collected. # The Python version has __del__, so it ends into gc.garbage instead @@ -1766,40 +1812,44 @@ def test_garbage_collection(self): def test_args_error(self): # Issue #17275 with self.assertRaisesRegex(TypeError, "BufferedReader"): - self.tp(io.BytesIO(), 1024, 1024, 1024) + self.tp(self.BytesIO(), 1024, 1024, 1024) def test_bad_readinto_value(self): - rawio = io.BufferedReader(io.BytesIO(b"12")) + rawio = self.tp(self.BytesIO(b"12")) rawio.readinto = lambda buf: -1 bufio = self.tp(rawio) with self.assertRaises(OSError) as cm: bufio.readline() self.assertIsNone(cm.exception.__cause__) - # TODO: RUSTPYTHON, TypeError: 'bytes' object cannot be interpreted as an integer") - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: 'bytes' object cannot be interpreted as an integer") def test_bad_readinto_type(self): - rawio = io.BufferedReader(io.BytesIO(b"12")) + rawio = self.tp(self.BytesIO(b"12")) rawio.readinto = lambda buf: b'' bufio = self.tp(rawio) with self.assertRaises(OSError) as cm: bufio.readline() self.assertIsInstance(cm.exception.__cause__, TypeError) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_flush_error_on_close(self): - super().test_flush_error_on_close() - - # TODO: RUSTPYTHON, AssertionError: UnsupportedOperation not raised by truncate - @unittest.expectedFailure - def test_truncate_on_read_only(self): # TODO: RUSTPYTHON, remove when this passes - super().test_truncate_on_read_only() # TODO: RUSTPYTHON, remove when this passes + return super().test_flush_error_on_close() - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_seek_character_device_file(self): - super().test_seek_character_device_file() + return super().test_seek_character_device_file() + + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: UnsupportedOperation not raised by truncate + def test_truncate_on_read_only(self): + return super().test_truncate_on_read_only() + + @unittest.skip('TODO: RUSTPYTHON; fallible allocation') + def test_constructor(self): + return super().test_constructor() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_pickling_subclass(self): + return super().test_pickling_subclass() class PyBufferedReaderTest(BufferedReaderTest): @@ -1909,8 +1959,7 @@ def _seekrel(bufio): def test_writes_and_truncates(self): self.check_writes(lambda bufio: bufio.truncate(bufio.tell())) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_write_non_blocking(self): raw = self.MockNonBlockWriterIO() bufio = self.tp(raw, 8) @@ -2107,20 +2156,6 @@ def test_slow_close_from_thread(self): class CBufferedWriterTest(BufferedWriterTest, SizeofTest): tp = io.BufferedWriter - @unittest.skip("TODO: RUSTPYTHON, fallible allocation") - @skip_if_sanitizer(memory=True, address=True, thread=True, - reason="sanitizer defaults to crashing " - "instead of returning NULL for malloc failure.") - def test_constructor(self): - BufferedWriterTest.test_constructor(self) - # The allocation can succeed on 32-bit builds, e.g. with more - # than 2 GiB RAM and a 64-bit kernel. - if sys.maxsize > 0x7FFFFFFF: - rawio = self.MockRawIO() - bufio = self.tp(rawio) - self.assertRaises((OverflowError, MemoryError, ValueError), - bufio.__init__, rawio, sys.maxsize) - def test_initialization(self): rawio = self.MockRawIO() bufio = self.tp(rawio) @@ -2131,8 +2166,7 @@ def test_initialization(self): self.assertRaises(ValueError, bufio.__init__, rawio, buffer_size=-1) self.assertRaises(ValueError, bufio.write, b"def") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_garbage_collection(self): # C BufferedWriter objects are collected, and collecting them flushes # all data to disk. @@ -2155,11 +2189,17 @@ def test_args_error(self): with self.assertRaisesRegex(TypeError, "BufferedWriter"): self.tp(self.BytesIO(), 1024, 1024, 1024) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_flush_error_on_close(self): - super().test_flush_error_on_close() + return super().test_flush_error_on_close() + + @unittest.skip('TODO: RUSTPYTHON; fallible allocation') + def test_constructor(self): + return super().test_constructor() + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_pickling_subclass(self): + return super().test_pickling_subclass() class PyBufferedWriterTest(BufferedWriterTest): tp = pyio.BufferedWriter @@ -2637,22 +2677,7 @@ def test_interleaved_readline_write(self): class CBufferedRandomTest(BufferedRandomTest, SizeofTest): tp = io.BufferedRandom - @unittest.skip("TODO: RUSTPYTHON, fallible allocation") - @skip_if_sanitizer(memory=True, address=True, thread=True, - reason="sanitizer defaults to crashing " - "instead of returning NULL for malloc failure.") - def test_constructor(self): - BufferedRandomTest.test_constructor(self) - # The allocation can succeed on 32-bit builds, e.g. with more - # than 2 GiB RAM and a 64-bit kernel. - if sys.maxsize > 0x7FFFFFFF: - rawio = self.MockRawIO() - bufio = self.tp(rawio) - self.assertRaises((OverflowError, MemoryError, ValueError), - bufio.__init__, rawio, sys.maxsize) - - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_garbage_collection(self): CBufferedReaderTest.test_garbage_collection(self) CBufferedWriterTest.test_garbage_collection(self) @@ -2662,20 +2687,25 @@ def test_args_error(self): with self.assertRaisesRegex(TypeError, "BufferedRandom"): self.tp(self.BytesIO(), 1024, 1024, 1024) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_flush_error_on_close(self): - super().test_flush_error_on_close() + return super().test_flush_error_on_close() - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_seek_character_device_file(self): - super().test_seek_character_device_file() + return super().test_seek_character_device_file() - # TODO: RUSTPYTHON; f.read1(1) returns b'a' - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON; f.read1(1) returns b'a' def test_read1_after_write(self): - super().test_read1_after_write() + return super().test_read1_after_write() + + @unittest.skip('TODO: RUSTPYTHON; fallible allocation') + def test_constructor(self): + return super().test_constructor() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_pickling_subclass(self): + return super().test_pickling_subclass() class PyBufferedRandomTest(BufferedRandomTest): @@ -2935,11 +2965,16 @@ def test_recursive_repr(self): # Issue #25455 raw = self.BytesIO() t = self.TextIOWrapper(raw, encoding="utf-8") - with support.swap_attr(raw, 'name', t): - try: + with support.swap_attr(raw, 'name', t), support.infinite_recursion(25): + with self.assertRaises(RuntimeError): repr(t) # Should not crash - except RuntimeError: - pass + + def test_subclass_repr(self): + class TestSubclass(self.TextIOWrapper): + pass + + f = TestSubclass(self.StringIO()) + self.assertIn(TestSubclass.__name__, repr(f)) def test_line_buffering(self): r = self.BytesIO() @@ -3166,6 +3201,7 @@ def flush(self): support.gc_collect() self.assertEqual(record, [1, 2, 3]) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_error_through_destructor(self): # Test that the exception state is not modified by a destructor, # even if close() fails. @@ -3174,10 +3210,7 @@ def test_error_through_destructor(self): with self.assertRaises(AttributeError): self.TextIOWrapper(rawio, encoding="utf-8").xyzzy - if not IOBASE_EMITS_UNRAISABLE: - self.assertIsNone(cm.unraisable) - elif cm.unraisable is not None: - self.assertEqual(cm.unraisable.exc_type, OSError) + self.assertEqual(cm.unraisable.exc_type, OSError) # Systematic tests of the text I/O API @@ -3327,8 +3360,7 @@ def test_seek_and_tell_with_data(data, min_pos=0): finally: StatefulIncrementalDecoder.codecEnabled = 0 - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_multibyte_seek_and_tell(self): f = self.open(os_helper.TESTFN, "w", encoding="euc_jp") f.write("AB\n\u3046\u3048\n") @@ -3344,8 +3376,7 @@ def test_multibyte_seek_and_tell(self): self.assertEqual(f.tell(), p1) f.close() - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_seek_with_encoder_state(self): f = self.open(os_helper.TESTFN, "w", encoding="euc_jis_2004") f.write("\u00e6\u0300") @@ -3359,8 +3390,7 @@ def test_seek_with_encoder_state(self): self.assertEqual(f.readline(), "\u00e6\u0300\u0300") f.close() - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_encoded_writes(self): data = "1234567890" tests = ("utf-16", @@ -3499,8 +3529,7 @@ def test_issue2282(self): self.assertEqual(buffer.seekable(), txt.seekable()) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_append_bom(self): # The BOM is not written again when appending to a non-empty file filename = os_helper.TESTFN @@ -3516,8 +3545,7 @@ def test_append_bom(self): with self.open(filename, 'rb') as f: self.assertEqual(f.read(), 'aaaxxx'.encode(charset)) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_seek_bom(self): # Same test, but when seeking manually filename = os_helper.TESTFN @@ -3533,8 +3561,7 @@ def test_seek_bom(self): with self.open(filename, 'rb') as f: self.assertEqual(f.read(), 'bbbzzz'.encode(charset)) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_seek_append_bom(self): # Same test, but first seek to the start and then to the end filename = os_helper.TESTFN @@ -3795,17 +3822,14 @@ def _check_create_at_shutdown(self, **kwargs): codecs.lookup('utf-8') class C: - def __init__(self): - self.buf = io.BytesIO() def __del__(self): - io.TextIOWrapper(self.buf, **{kwargs}) + io.TextIOWrapper(io.BytesIO(), **{kwargs}) print("ok") c = C() """.format(iomod=iomod, kwargs=kwargs) return assert_python_ok("-c", code) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_create_at_shutdown_without_encoding(self): rc, out, err = self._check_create_at_shutdown() if err: @@ -3815,8 +3839,7 @@ def test_create_at_shutdown_without_encoding(self): else: self.assertEqual("ok", out.decode().strip()) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_create_at_shutdown_with_encoding(self): rc, out, err = self._check_create_at_shutdown(encoding='utf-8', errors='strict') @@ -4042,6 +4065,28 @@ def test_issue35928(self): f.write(res) self.assertEqual(res + f.readline(), 'foo\nbar\n') + def test_pickling_subclass(self): + global MyTextIO + class MyTextIO(self.TextIOWrapper): + def __init__(self, raw, tag): + super().__init__(raw) + self.tag = tag + def __getstate__(self): + return self.tag, self.buffer.getvalue() + def __setstate__(slf, state): + tag, value = state + slf.__init__(self.BytesIO(value), tag) + + raw = self.BytesIO(b'data') + txt = MyTextIO(raw, 'ham') + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(protocol=proto): + pickled = pickle.dumps(txt, proto) + newtxt = pickle.loads(pickled) + self.assertEqual(newtxt.buffer.getvalue(), b'data') + self.assertEqual(newtxt.tag, 'ham') + del MyTextIO + class MemviewBytesIO(io.BytesIO): '''A BytesIO object whose read method returns memoryviews @@ -4066,98 +4111,7 @@ class CTextIOWrapperTest(TextIOWrapperTest): io = io shutdown_error = "LookupError: unknown encoding: ascii" - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_constructor(self): - super().test_constructor() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_detach(self): - super().test_detach() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_reconfigure_encoding_read(self): - super().test_reconfigure_encoding_read() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_reconfigure_line_buffering(self): - super().test_reconfigure_line_buffering() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_basic_io(self): - super().test_basic_io() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_telling(self): - super().test_telling() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_uninitialized(self): - super().test_uninitialized() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_non_text_encoding_codecs_are_rejected(self): - super().test_non_text_encoding_codecs_are_rejected() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_repr(self): - super().test_repr() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_newlines(self): - super().test_newlines() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_newlines_input(self): - super().test_newlines_input() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_reconfigure_write_through(self): - super().test_reconfigure_write_through() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_reconfigure_write_fromascii(self): - super().test_reconfigure_write_fromascii() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_reconfigure_write(self): - super().test_reconfigure_write() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_reconfigure_defaults(self): - super().test_reconfigure_defaults() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_reconfigure_newline(self): - super().test_reconfigure_newline() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_reconfigure_errors(self): - super().test_reconfigure_errors() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_reconfigure_locale(self): - super().test_reconfigure_locale() - - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_initialization(self): r = self.BytesIO(b"\xc3\xa9\n\n") b = self.BufferedReader(r, 1000) @@ -4168,8 +4122,7 @@ def test_initialization(self): t = self.TextIOWrapper.__new__(self.TextIOWrapper) self.assertRaises(Exception, repr, t) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_garbage_collection(self): # C TextIOWrapper objects are collected, and collecting them flushes # all data to disk. @@ -4233,20 +4186,121 @@ def write(self, data): t.write("x"*chunk_size) self.assertEqual([b"abcdef", b"ghi", b"x"*chunk_size], buf._write_stack) + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_issue119506(self): + chunk_size = 8192 + + class MockIO(self.MockRawIO): + written = False + def write(self, data): + if not self.written: + self.written = True + t.write("middle") + return super().write(data) + + buf = MockIO() + t = self.TextIOWrapper(buf) + t.write("abc") + t.write("def") + # writing data which size >= chunk_size cause flushing buffer before write. + t.write("g" * chunk_size) + t.flush() + + self.assertEqual([b"abcdef", b"middle", b"g"*chunk_size], + buf._write_stack) + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_basic_io(self): + return super().test_basic_io() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_constructor(self): + return super().test_constructor() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_detach(self): + return super().test_detach() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_newlines(self): + return super().test_newlines() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_newlines_input(self): + return super().test_newlines_input() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_non_text_encoding_codecs_are_rejected(self): + return super().test_non_text_encoding_codecs_are_rejected() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_reconfigure_defaults(self): + return super().test_reconfigure_defaults() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_reconfigure_encoding_read(self): + return super().test_reconfigure_encoding_read() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_reconfigure_errors(self): + return super().test_reconfigure_errors() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_reconfigure_line_buffering(self): + return super().test_reconfigure_line_buffering() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_reconfigure_locale(self): + return super().test_reconfigure_locale() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_reconfigure_newline(self): + return super().test_reconfigure_newline() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_reconfigure_write(self): + return super().test_reconfigure_write() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_reconfigure_write_fromascii(self): + return super().test_reconfigure_write_fromascii() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_reconfigure_write_through(self): + return super().test_reconfigure_write_through() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_repr(self): + return super().test_repr() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_telling(self): + return super().test_telling() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_uninitialized(self): + return super().test_uninitialized() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_recursive_repr(self): + return super().test_recursive_repr() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_pickling_subclass(self): + return super().test_pickling_subclass() + class PyTextIOWrapperTest(TextIOWrapperTest): io = pyio shutdown_error = "LookupError: unknown encoding: ascii" - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: ValueError not raised def test_constructor(self): - super().test_constructor() + return super().test_constructor() - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_newlines(self): - super().test_newlines() + return super().test_newlines() class IncrementalNewlineDecoderTest(unittest.TestCase): @@ -4326,8 +4380,7 @@ def _decode_bytewise(s): self.assertEqual(decoder.decode(input), "abc") self.assertEqual(decoder.newlines, None) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_newline_decoder(self): encodings = ( # None meaning the IncrementalNewlineDecoder takes unicode input @@ -4489,8 +4542,7 @@ def test_io_after_close(self): self.assertRaises(ValueError, f.writelines, []) self.assertRaises(ValueError, next, f) - # TODO: RUSTPYTHON, cyclic gc - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON; cyclic gc def test_blockingioerror(self): # Various BlockingIOError issues class C(str): @@ -4538,15 +4590,14 @@ def test_abc_inheritance_official(self): self._check_abc_inheritance(io) def _check_warn_on_dealloc(self, *args, **kwargs): - f = open(*args, **kwargs) + f = self.open(*args, **kwargs) r = repr(f) with self.assertWarns(ResourceWarning) as cm: f = None support.gc_collect() self.assertIn(r, str(cm.warning.args[0])) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_warn_on_dealloc(self): self._check_warn_on_dealloc(os_helper.TESTFN, "wb", buffering=0) self._check_warn_on_dealloc(os_helper.TESTFN, "wb") @@ -4569,10 +4620,9 @@ def cleanup_fds(): r, w = os.pipe() fds += r, w with warnings_helper.check_no_resource_warning(self): - open(r, *args, closefd=False, **kwargs) + self.open(r, *args, closefd=False, **kwargs) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @unittest.skipUnless(hasattr(os, "pipe"), "requires os.pipe()") def test_warn_on_dealloc_fd(self): self._check_warn_on_dealloc_fd("rb", buffering=0) @@ -4602,16 +4652,14 @@ def test_pickling(self): with self.assertRaisesRegex(TypeError, msg): pickle.dumps(f, protocol) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @unittest.skipIf( support.is_emscripten, "fstat() of a pipe fd is not supported" ) def test_nonblock_pipe_write_bigbuf(self): self._test_nonblock_pipe_write(16*1024) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @unittest.skipIf( support.is_emscripten, "fstat() of a pipe fd is not supported" ) @@ -4733,8 +4781,7 @@ def test_check_encoding_errors(self): proc = assert_python_failure('-X', 'dev', '-c', code) self.assertEqual(proc.rc, 10, proc) - # TODO: RUSTPYTHON, AssertionError: 0 != 2 - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: 0 != 2 def test_check_encoding_warning(self): # PEP 597: Raise warning when encoding is not specified # and sys.flags.warn_default_encoding is set. @@ -4758,8 +4805,7 @@ def test_check_encoding_warning(self): self.assertTrue( warnings[1].startswith(b":8: EncodingWarning: ")) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_text_encoding(self): # PEP 597, bpo-47000. io.text_encoding() returns "locale" or "utf-8" # based on sys.flags.utf8_mode @@ -4837,10 +4883,9 @@ def test_daemon_threads_shutdown_stdout_deadlock(self): def test_daemon_threads_shutdown_stderr_deadlock(self): self.check_daemon_threads_shutdown_deadlock('stderr') - # TODO: RUSTPYTHON, AssertionError: 22 != 10 : _PythonRunResult(rc=22, out=b'', err=b'') - @unittest.expectedFailure - def test_check_encoding_errors(self): # TODO: RUSTPYTHON, remove when this passes - super().test_check_encoding_errors() # TODO: RUSTPYTHON, remove when this passes + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: 22 != 10 : _PythonRunResult(rc=22, out=b'', err=b'') + def test_check_encoding_errors(self): + return super().test_check_encoding_errors() class PyMiscIOTest(MiscIOTest): @@ -5014,16 +5059,14 @@ def alarm_handler(sig, frame): os.close(w) os.close(r) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @requires_alarm @support.requires_resource('walltime') def test_interrupted_read_retry_buffered(self): self.check_interrupted_read_retry(lambda x: x.decode('latin1'), mode="rb") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @requires_alarm @support.requires_resource('walltime') def test_interrupted_read_retry_text(self): @@ -5098,15 +5141,13 @@ def alarm2(sig, frame): if e.errno != errno.EBADF: raise - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @requires_alarm @support.requires_resource('walltime') def test_interrupted_write_retry_buffered(self): self.check_interrupted_write_retry(b"x", mode="wb") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @requires_alarm @support.requires_resource('walltime') def test_interrupted_write_retry_text(self): diff --git a/Lib/test/test_shutil.py b/Lib/test/test_shutil.py index 8a426e338a..1fe2aef39d 100644 --- a/Lib/test/test_shutil.py +++ b/Lib/test/test_shutil.py @@ -196,6 +196,7 @@ def test_rmtree_works_on_bytes(self): self.assertIsInstance(victim, bytes) shutil.rmtree(victim) + @unittest.skipIf(sys.platform == 'win32', 'TODO: RUSTPYTHON; flaky') @os_helper.skip_unless_symlink def test_rmtree_fails_on_symlink_onerror(self): tmp = self.mkdtemp() @@ -1477,7 +1478,7 @@ def test_dont_copy_file_onto_link_to_itself(self): finally: shutil.rmtree(TESTFN, ignore_errors=True) - @unittest.expectedFailureIfWindows('TODO: RUSTPYTHON') + @unittest.expectedFailureIfWindows('TODO: RUSTPYTHON; AssertionError: SameFileError not raised for copyfile') @os_helper.skip_unless_symlink def test_dont_copy_file_onto_symlink_to_itself(self): # bug 851123. diff --git a/Lib/test/test_uuid.py b/Lib/test/test_uuid.py index 069221ae47..4aa15f6993 100644 --- a/Lib/test/test_uuid.py +++ b/Lib/test/test_uuid.py @@ -1,6 +1,7 @@ import unittest from test import support from test.support import import_helper +from test.support.script_helper import assert_python_ok import builtins import contextlib import copy @@ -32,8 +33,7 @@ def get_command_stdout(command, args): class BaseTestUUID: uuid = None - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_safe_uuid_enum(self): class CheckedSafeUUID(enum.Enum): safe = 0 @@ -775,10 +775,37 @@ def test_cli_uuid5_ouputted_with_valid_namespace_and_name(self): class TestUUIDWithoutExtModule(BaseTestUUID, unittest.TestCase): uuid = py_uuid + @unittest.skipUnless(c_uuid, 'requires the C _uuid module') class TestUUIDWithExtModule(BaseTestUUID, unittest.TestCase): uuid = c_uuid + def check_has_stable_libuuid_extractable_node(self): + if not self.uuid._has_stable_extractable_node: + self.skipTest("libuuid cannot deduce MAC address") + + @unittest.skipUnless(os.name == 'posix', 'POSIX only') + def test_unix_getnode_from_libuuid(self): + self.check_has_stable_libuuid_extractable_node() + script = 'import uuid; print(uuid._unix_getnode())' + _, n_a, _ = assert_python_ok('-c', script) + _, n_b, _ = assert_python_ok('-c', script) + n_a, n_b = n_a.decode().strip(), n_b.decode().strip() + self.assertTrue(n_a.isdigit()) + self.assertTrue(n_b.isdigit()) + self.assertEqual(n_a, n_b) + + @unittest.skipUnless(os.name == 'nt', 'Windows only') + def test_windows_getnode_from_libuuid(self): + self.check_has_stable_libuuid_extractable_node() + script = 'import uuid; print(uuid._windll_getnode())' + _, n_a, _ = assert_python_ok('-c', script) + _, n_b, _ = assert_python_ok('-c', script) + n_a, n_b = n_a.decode().strip(), n_b.decode().strip() + self.assertTrue(n_a.isdigit()) + self.assertTrue(n_b.isdigit()) + self.assertEqual(n_a, n_b) + class BaseTestInternals: _uuid = py_uuid diff --git a/Lib/uuid.py b/Lib/uuid.py index c286eac38e..55f46eb510 100644 --- a/Lib/uuid.py +++ b/Lib/uuid.py @@ -572,39 +572,43 @@ def _netstat_getnode(): try: import _uuid _generate_time_safe = getattr(_uuid, "generate_time_safe", None) + _has_stable_extractable_node = getattr(_uuid, "has_stable_extractable_node", False) _UuidCreate = getattr(_uuid, "UuidCreate", None) except ImportError: _uuid = None _generate_time_safe = None + _has_stable_extractable_node = False _UuidCreate = None def _unix_getnode(): """Get the hardware address on Unix using the _uuid extension module.""" - if _generate_time_safe: + if _generate_time_safe and _has_stable_extractable_node: uuid_time, _ = _generate_time_safe() return UUID(bytes=uuid_time).node def _windll_getnode(): """Get the hardware address on Windows using the _uuid extension module.""" - if _UuidCreate: + if _UuidCreate and _has_stable_extractable_node: uuid_bytes = _UuidCreate() return UUID(bytes_le=uuid_bytes).node def _random_getnode(): """Get a random node ID.""" - # RFC 4122, $4.1.6 says "For systems with no IEEE address, a randomly or - # pseudo-randomly generated value may be used; see Section 4.5. The - # multicast bit must be set in such addresses, in order that they will - # never conflict with addresses obtained from network cards." + # RFC 9562, §6.10-3 says that + # + # Implementations MAY elect to obtain a 48-bit cryptographic-quality + # random number as per Section 6.9 to use as the Node ID. [...] [and] + # implementations MUST set the least significant bit of the first octet + # of the Node ID to 1. This bit is the unicast or multicast bit, which + # will never be set in IEEE 802 addresses obtained from network cards. # # The "multicast bit" of a MAC address is defined to be "the least # significant bit of the first octet". This works out to be the 41st bit # counting from 1 being the least significant bit, or 1<<40. # # See https://en.wikipedia.org/w/index.php?title=MAC_address&oldid=1128764812#Universal_vs._local_(U/L_bit) - import random - return random.getrandbits(48) | (1 << 40) + return int.from_bytes(os.urandom(6)) | (1 << 40) # _OS_GETTERS, when known, are targeted for a specific OS or platform. diff --git a/compiler/codegen/src/ir.rs b/compiler/codegen/src/ir.rs index 1cc59dd656..31c8926091 100644 --- a/compiler/codegen/src/ir.rs +++ b/compiler/codegen/src/ir.rs @@ -5,7 +5,7 @@ use rustpython_compiler_core::{ OneIndexed, SourceLocation, bytecode::{ CodeFlags, CodeObject, CodeUnit, ConstantData, InstrDisplayContext, Instruction, Label, - OpArg, + OpArg, PyCodeLocationInfoKind, }, }; @@ -72,6 +72,7 @@ pub struct InstructionInfo { pub target: BlockIdx, // pub range: TextRange, pub location: SourceLocation, + // TODO: end_location for debug ranges } // spell-checker:ignore petgraph @@ -199,6 +200,9 @@ impl CodeInfo { locations.clear() } + // Generate linetable from locations + let linetable = generate_linetable(&locations, first_line_number.get() as i32); + Ok(CodeObject { flags, posonlyarg_count, @@ -218,6 +222,8 @@ impl CodeInfo { cellvars: cellvar_cache.into_iter().collect(), freevars: freevar_cache.into_iter().collect(), cell2arg, + linetable, + exceptiontable: Box::new([]), // TODO: Generate actual exception table }) } @@ -388,3 +394,134 @@ fn iter_blocks(blocks: &[Block]) -> impl Iterator + ' Some((idx, b)) }) } + +/// Generate CPython 3.11+ format linetable from source locations +fn generate_linetable(locations: &[SourceLocation], first_line: i32) -> Box<[u8]> { + if locations.is_empty() { + return Box::new([]); + } + + let mut linetable = Vec::new(); + // Initialize prev_line to first_line + // The first entry's delta is relative to co_firstlineno + let mut prev_line = first_line; + let mut i = 0; + + while i < locations.len() { + let loc = &locations[i]; + + // Count consecutive instructions with the same location + let mut length = 1; + while i + length < locations.len() && locations[i + length] == locations[i] { + length += 1; + } + + // Process in chunks of up to 8 instructions + while length > 0 { + let entry_length = length.min(8); + + // Get line and column information + // SourceLocation always has row and column (both are OneIndexed) + let line = loc.row.get() as i32; + let col = (loc.column.get() as i32) - 1; // Convert 1-based to 0-based + + let line_delta = line - prev_line; + + // Choose the appropriate encoding based on line delta and column info + // Note: SourceLocation always has valid column, so we never get NO_COLUMNS case + if line_delta == 0 { + let end_col = col; // Use same column for end (no range info available) + + if col < 80 && end_col - col < 16 && end_col >= col { + // Short form (codes 0-9) for common cases + let code = (col / 8).min(9) as u8; // Short0 to Short9 + linetable.push(0x80 | (code << 3) | ((entry_length - 1) as u8)); + let col_byte = (((col % 8) as u8) << 4) | ((end_col - col) as u8 & 0xf); + linetable.push(col_byte); + } else if col < 128 && end_col < 128 { + // One-line form (code 10) for same line + linetable.push( + 0x80 | ((PyCodeLocationInfoKind::OneLine0 as u8) << 3) + | ((entry_length - 1) as u8), + ); + linetable.push(col as u8); + linetable.push(end_col as u8); + } else { + // Long form for columns >= 128 + linetable.push( + 0x80 | ((PyCodeLocationInfoKind::Long as u8) << 3) + | ((entry_length - 1) as u8), + ); + write_signed_varint(&mut linetable, 0); // line_delta = 0 + write_varint(&mut linetable, 0); // end_line delta = 0 + write_varint(&mut linetable, (col as u32) + 1); // column + 1 for encoding + write_varint(&mut linetable, (end_col as u32) + 1); // end_col + 1 + } + } else if line_delta > 0 && line_delta < 3 + /* && column.is_some() */ + { + // One-line form (codes 11-12) for line deltas 1-2 + let end_col = col; // Use same column for end + + if col < 128 && end_col < 128 { + let code = (PyCodeLocationInfoKind::OneLine0 as u8) + (line_delta as u8); // 11 for delta=1, 12 for delta=2 + linetable.push(0x80 | (code << 3) | ((entry_length - 1) as u8)); + linetable.push(col as u8); + linetable.push(end_col as u8); + } else { + // Long form for columns >= 128 or negative line delta + linetable.push( + 0x80 | ((PyCodeLocationInfoKind::Long as u8) << 3) + | ((entry_length - 1) as u8), + ); + write_signed_varint(&mut linetable, line_delta); + write_varint(&mut linetable, 0); // end_line delta = 0 + write_varint(&mut linetable, (col as u32) + 1); // column + 1 for encoding + write_varint(&mut linetable, (end_col as u32) + 1); // end_col + 1 + } + } else { + // Long form (code 14) for all other cases + // This handles: line_delta < 0, line_delta >= 3, or columns >= 128 + let end_col = col; // Use same column for end + linetable.push( + 0x80 | ((PyCodeLocationInfoKind::Long as u8) << 3) | ((entry_length - 1) as u8), + ); + write_signed_varint(&mut linetable, line_delta); + write_varint(&mut linetable, 0); // end_line delta = 0 + write_varint(&mut linetable, (col as u32) + 1); // column + 1 for encoding + write_varint(&mut linetable, (end_col as u32) + 1); // end_col + 1 + } + + prev_line = line; + length -= entry_length; + i += entry_length; + } + } + + linetable.into_boxed_slice() +} + +/// Write a variable-length unsigned integer (6-bit chunks) +/// Returns the number of bytes written +fn write_varint(buf: &mut Vec, mut val: u32) -> usize { + let start_len = buf.len(); + while val >= 64 { + buf.push(0x40 | (val & 0x3f) as u8); + val >>= 6; + } + buf.push(val as u8); + buf.len() - start_len +} + +/// Write a variable-length signed integer +/// Returns the number of bytes written +fn write_signed_varint(buf: &mut Vec, val: i32) -> usize { + let uval = if val < 0 { + // (unsigned int)(-val) has an undefined behavior for INT_MIN + // So we use (0 - val as u32) to handle it correctly + ((0u32.wrapping_sub(val as u32)) << 1) | 1 + } else { + (val as u32) << 1 + }; + write_varint(buf, uval) +} diff --git a/compiler/core/src/bytecode.rs b/compiler/core/src/bytecode.rs index b38c599508..c2ce4e52c0 100644 --- a/compiler/core/src/bytecode.rs +++ b/compiler/core/src/bytecode.rs @@ -33,6 +33,75 @@ pub enum ResumeType { AfterAwait = 3, } +/// CPython 3.11+ linetable location info codes +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[repr(u8)] +pub enum PyCodeLocationInfoKind { + // Short forms are 0 to 9 + Short0 = 0, + Short1 = 1, + Short2 = 2, + Short3 = 3, + Short4 = 4, + Short5 = 5, + Short6 = 6, + Short7 = 7, + Short8 = 8, + Short9 = 9, + // One line forms are 10 to 12 + OneLine0 = 10, + OneLine1 = 11, + OneLine2 = 12, + NoColumns = 13, + Long = 14, + None = 15, +} + +impl PyCodeLocationInfoKind { + pub fn from_code(code: u8) -> Option { + match code { + 0 => Some(Self::Short0), + 1 => Some(Self::Short1), + 2 => Some(Self::Short2), + 3 => Some(Self::Short3), + 4 => Some(Self::Short4), + 5 => Some(Self::Short5), + 6 => Some(Self::Short6), + 7 => Some(Self::Short7), + 8 => Some(Self::Short8), + 9 => Some(Self::Short9), + 10 => Some(Self::OneLine0), + 11 => Some(Self::OneLine1), + 12 => Some(Self::OneLine2), + 13 => Some(Self::NoColumns), + 14 => Some(Self::Long), + 15 => Some(Self::None), + _ => Option::None, + } + } + + pub fn is_short(&self) -> bool { + (*self as u8) <= 9 + } + + pub fn short_column_group(&self) -> Option { + if self.is_short() { + Some(*self as u8) + } else { + Option::None + } + } + + pub fn one_line_delta(&self) -> Option { + match self { + Self::OneLine0 => Some(0), + Self::OneLine1 => Some(1), + Self::OneLine2 => Some(2), + _ => Option::None, + } + } +} + pub trait Constant: Sized { type Name: AsRef; @@ -146,6 +215,10 @@ pub struct CodeObject { pub varnames: Box<[C::Name]>, pub cellvars: Box<[C::Name]>, pub freevars: Box<[C::Name]>, + pub linetable: Box<[u8]>, + // Line number table (CPython 3.11+ format) + pub exceptiontable: Box<[u8]>, + // Exception handling table } bitflags! { @@ -1202,6 +1275,8 @@ impl CodeObject { first_line_number: self.first_line_number, max_stackdepth: self.max_stackdepth, cell2arg: self.cell2arg, + linetable: self.linetable, + exceptiontable: self.exceptiontable, } } @@ -1232,6 +1307,8 @@ impl CodeObject { first_line_number: self.first_line_number, max_stackdepth: self.max_stackdepth, cell2arg: self.cell2arg.clone(), + linetable: self.linetable.clone(), + exceptiontable: self.exceptiontable.clone(), } } } diff --git a/compiler/core/src/marshal.rs b/compiler/core/src/marshal.rs index ff82340c0e..b8044a1ab9 100644 --- a/compiler/core/src/marshal.rs +++ b/compiler/core/src/marshal.rs @@ -251,6 +251,16 @@ pub fn deserialize_code( let cellvars = read_names()?; let freevars = read_names()?; + // Read linetable and exceptiontable + let linetable_len = rdr.read_u32()?; + let linetable = rdr.read_slice(linetable_len)?.to_vec().into_boxed_slice(); + + let exceptiontable_len = rdr.read_u32()?; + let exceptiontable = rdr + .read_slice(exceptiontable_len)? + .to_vec() + .into_boxed_slice(); + Ok(CodeObject { instructions, locations, @@ -269,6 +279,8 @@ pub fn deserialize_code( varnames, cellvars, freevars, + linetable, + exceptiontable, }) } @@ -684,4 +696,8 @@ pub fn serialize_code(buf: &mut W, code: &CodeObject) write_names(&code.varnames); write_names(&code.cellvars); write_names(&code.freevars); + + // Serialize linetable and exceptiontable + write_vec(buf, &code.linetable); + write_vec(buf, &code.exceptiontable); } diff --git a/stdlib/src/uuid.rs b/stdlib/src/uuid.rs index 9b0e23a81c..3f75db402c 100644 --- a/stdlib/src/uuid.rs +++ b/stdlib/src/uuid.rs @@ -30,4 +30,7 @@ mod _uuid { fn has_uuid_generate_time_safe(_vm: &VirtualMachine) -> u32 { 0 } + + #[pyattr(name = "has_stable_extractable_node")] + const HAS_STABLE_EXTRACTABLE_NODE: bool = false; } diff --git a/vm/src/builtins/code.rs b/vm/src/builtins/code.rs index 1ce0f3b3e0..b49f76caa0 100644 --- a/vm/src/builtins/code.rs +++ b/vm/src/builtins/code.rs @@ -16,8 +16,125 @@ use crate::{ use malachite_bigint::BigInt; use num_traits::Zero; use rustpython_compiler_core::OneIndexed; +use rustpython_compiler_core::bytecode::PyCodeLocationInfoKind; use std::{borrow::Borrow, fmt, ops::Deref}; +/// State for iterating through code address ranges +struct PyCodeAddressRange<'a> { + ar_start: i32, + ar_end: i32, + ar_line: i32, + computed_line: i32, + reader: LineTableReader<'a>, +} + +impl<'a> PyCodeAddressRange<'a> { + fn new(linetable: &'a [u8], first_line: i32) -> Self { + PyCodeAddressRange { + ar_start: 0, + ar_end: 0, + ar_line: -1, + computed_line: first_line, + reader: LineTableReader::new(linetable), + } + } + + /// Check if this is a NO_LINE marker (code 15) + fn is_no_line_marker(byte: u8) -> bool { + (byte >> 3) == 0x1f + } + + /// Advance to next address range + fn advance(&mut self) -> bool { + if self.reader.at_end() { + return false; + } + + let first_byte = match self.reader.read_byte() { + Some(b) => b, + None => return false, + }; + + if (first_byte & 0x80) == 0 { + return false; // Invalid linetable + } + + let code = (first_byte >> 3) & 0x0f; + let length = ((first_byte & 0x07) + 1) as i32; + + // Get line delta for this entry + let line_delta = self.get_line_delta(code); + + // Update computed line + self.computed_line += line_delta; + + // Check for NO_LINE marker + if Self::is_no_line_marker(first_byte) { + self.ar_line = -1; + } else { + self.ar_line = self.computed_line; + } + + // Update address range + self.ar_start = self.ar_end; + self.ar_end += length * 2; // sizeof(_Py_CODEUNIT) = 2 + + // Skip remaining bytes for this entry + while !self.reader.at_end() { + if let Some(b) = self.reader.peek_byte() { + if (b & 0x80) != 0 { + break; + } + self.reader.read_byte(); + } else { + break; + } + } + + true + } + + fn get_line_delta(&mut self, code: u8) -> i32 { + let kind = match PyCodeLocationInfoKind::from_code(code) { + Some(k) => k, + None => return 0, + }; + + match kind { + PyCodeLocationInfoKind::None => 0, // NO_LINE marker + PyCodeLocationInfoKind::Long => { + let delta = self.reader.read_signed_varint(); + // Skip end_line, col, end_col + self.reader.read_varint(); + self.reader.read_varint(); + self.reader.read_varint(); + delta + } + PyCodeLocationInfoKind::NoColumns => self.reader.read_signed_varint(), + PyCodeLocationInfoKind::OneLine0 => { + self.reader.read_byte(); // Skip column + self.reader.read_byte(); // Skip end column + 0 + } + PyCodeLocationInfoKind::OneLine1 => { + self.reader.read_byte(); // Skip column + self.reader.read_byte(); // Skip end column + 1 + } + PyCodeLocationInfoKind::OneLine2 => { + self.reader.read_byte(); // Skip column + self.reader.read_byte(); // Skip end column + 2 + } + _ if kind.is_short() => { + self.reader.read_byte(); // Skip column byte + 0 + } + _ => 0, + } + } +} + #[derive(FromArgs)] pub struct ReplaceArgs { #[pyarg(named, optional)] @@ -40,6 +157,22 @@ pub struct ReplaceArgs { co_flags: OptionalArg, #[pyarg(named, optional)] co_varnames: OptionalArg>, + #[pyarg(named, optional)] + co_nlocals: OptionalArg, + #[pyarg(named, optional)] + co_stacksize: OptionalArg, + #[pyarg(named, optional)] + co_code: OptionalArg, + #[pyarg(named, optional)] + co_linetable: OptionalArg, + #[pyarg(named, optional)] + co_exceptiontable: OptionalArg, + #[pyarg(named, optional)] + co_freevars: OptionalArg>, + #[pyarg(named, optional)] + co_cellvars: OptionalArg>, + #[pyarg(named, optional)] + co_qualname: OptionalArg, } #[derive(Clone)] @@ -350,6 +483,211 @@ impl PyCode { vm.ctx.new_tuple(names) } + #[pygetset] + pub fn co_linetable(&self, vm: &VirtualMachine) -> crate::builtins::PyBytesRef { + // Return the actual linetable from the code object + vm.ctx.new_bytes(self.code.linetable.to_vec()) + } + + #[pygetset] + pub fn co_exceptiontable(&self, vm: &VirtualMachine) -> crate::builtins::PyBytesRef { + // Return the actual exception table from the code object + vm.ctx.new_bytes(self.code.exceptiontable.to_vec()) + } + + #[pymethod] + pub fn co_lines(&self, vm: &VirtualMachine) -> PyResult { + // TODO: Implement lazy iterator (lineiterator) like CPython for better performance + // Currently returns eager list for simplicity + + // Return an iterator over (start_offset, end_offset, lineno) tuples + let linetable = self.code.linetable.as_ref(); + let mut lines = Vec::new(); + + if !linetable.is_empty() { + let first_line = self.code.first_line_number.map_or(0, |n| n.get() as i32); + let mut range = PyCodeAddressRange::new(linetable, first_line); + + // Process all address ranges and merge consecutive entries with same line + let mut pending_entry: Option<(i32, i32, i32)> = None; + + while range.advance() { + let start = range.ar_start; + let end = range.ar_end; + let line = range.ar_line; + + if let Some((prev_start, _, prev_line)) = pending_entry { + if prev_line == line { + // Same line, extend the range + pending_entry = Some((prev_start, end, prev_line)); + } else { + // Different line, emit the previous entry + let tuple = if prev_line == -1 { + vm.ctx.new_tuple(vec![ + vm.ctx.new_int(prev_start).into(), + vm.ctx.new_int(start).into(), + vm.ctx.none(), + ]) + } else { + vm.ctx.new_tuple(vec![ + vm.ctx.new_int(prev_start).into(), + vm.ctx.new_int(start).into(), + vm.ctx.new_int(prev_line).into(), + ]) + }; + lines.push(tuple.into()); + pending_entry = Some((start, end, line)); + } + } else { + // First entry + pending_entry = Some((start, end, line)); + } + } + + // Emit the last pending entry + if let Some((start, end, line)) = pending_entry { + let tuple = if line == -1 { + vm.ctx.new_tuple(vec![ + vm.ctx.new_int(start).into(), + vm.ctx.new_int(end).into(), + vm.ctx.none(), + ]) + } else { + vm.ctx.new_tuple(vec![ + vm.ctx.new_int(start).into(), + vm.ctx.new_int(end).into(), + vm.ctx.new_int(line).into(), + ]) + }; + lines.push(tuple.into()); + } + } + + let list = vm.ctx.new_list(lines); + vm.call_method(list.as_object(), "__iter__", ()) + } + + #[pymethod] + pub fn co_positions(&self, vm: &VirtualMachine) -> PyResult { + // Return an iterator over (line, end_line, column, end_column) tuples for each instruction + let linetable = self.code.linetable.as_ref(); + let mut positions = Vec::new(); + + if !linetable.is_empty() { + let mut reader = LineTableReader::new(linetable); + let mut line = self.code.first_line_number.map_or(0, |n| n.get() as i32); + + while !reader.at_end() { + let first_byte = match reader.read_byte() { + Some(b) => b, + None => break, + }; + + if (first_byte & 0x80) == 0 { + break; // Invalid linetable + } + + let code = (first_byte >> 3) & 0x0f; + let length = ((first_byte & 0x07) + 1) as i32; + + let kind = match PyCodeLocationInfoKind::from_code(code) { + Some(k) => k, + None => break, // Invalid code + }; + + let (line_delta, end_line_delta, column, end_column): ( + i32, + i32, + Option, + Option, + ) = match kind { + PyCodeLocationInfoKind::None => { + // No location - all values are None + (0, 0, None, None) + } + PyCodeLocationInfoKind::Long => { + // Long form + let delta = reader.read_signed_varint(); + let end_line_delta = reader.read_varint() as i32; + + let col = reader.read_varint(); + let column = if col == 0 { + None + } else { + Some((col - 1) as i32) + }; + + let end_col = reader.read_varint(); + let end_column = if end_col == 0 { + None + } else { + Some((end_col - 1) as i32) + }; + + // endline = line + end_line_delta (will be computed after line update) + (delta, end_line_delta, column, end_column) + } + PyCodeLocationInfoKind::NoColumns => { + // No column form + let delta = reader.read_signed_varint(); + (delta, 0, None, None) // endline will be same as line (delta = 0) + } + PyCodeLocationInfoKind::OneLine0 + | PyCodeLocationInfoKind::OneLine1 + | PyCodeLocationInfoKind::OneLine2 => { + // One-line form - endline = line + let col = reader.read_byte().unwrap_or(0) as i32; + let end_col = reader.read_byte().unwrap_or(0) as i32; + let delta = kind.one_line_delta().unwrap_or(0); + (delta, 0, Some(col), Some(end_col)) // endline = line (delta = 0) + } + _ if kind.is_short() => { + // Short form - endline = line + let col_data = reader.read_byte().unwrap_or(0); + let col_group = kind.short_column_group().unwrap_or(0); + let col = ((col_group as i32) << 3) | ((col_data >> 4) as i32); + let end_col = col + (col_data & 0x0f) as i32; + (0, 0, Some(col), Some(end_col)) // endline = line (delta = 0) + } + _ => (0, 0, None, None), + }; + + // Update line number + line += line_delta; + + // Generate position tuples for each instruction covered by this entry + for _ in 0..length { + // Handle special case for no location (code 15) + let final_line = if kind == PyCodeLocationInfoKind::None { + None + } else { + Some(line) + }; + + let final_endline = if kind == PyCodeLocationInfoKind::None { + None + } else { + Some(line + end_line_delta) + }; + + // Convert Option to PyObject (None or int) + let line_obj = final_line.to_pyobject(vm); + let end_line_obj = final_endline.to_pyobject(vm); + let column_obj = column.to_pyobject(vm); + let end_column_obj = end_column.to_pyobject(vm); + + let tuple = + vm.ctx + .new_tuple(vec![line_obj, end_line_obj, column_obj, end_column_obj]); + positions.push(tuple.into()); + } + } + } + + let list = vm.ctx.new_list(positions); + vm.call_method(list.as_object(), "__iter__", ()) + } + #[pymethod] pub fn replace(&self, args: ReplaceArgs, vm: &VirtualMachine) -> PyResult { let posonlyarg_count = match args.co_posonlyargcount { @@ -408,6 +746,66 @@ impl PyCode { OptionalArg::Missing => self.code.varnames.iter().map(|s| s.to_object()).collect(), }; + let qualname = match args.co_qualname { + OptionalArg::Present(qualname) => qualname, + OptionalArg::Missing => self.code.qualname.to_owned(), + }; + + let max_stackdepth = match args.co_stacksize { + OptionalArg::Present(stacksize) => stacksize, + OptionalArg::Missing => self.code.max_stackdepth, + }; + + let instructions = match args.co_code { + OptionalArg::Present(_code_bytes) => { + // Convert bytes back to instructions + // For now, keep the original instructions + // TODO: Properly parse bytecode from bytes + self.code.instructions.clone() + } + OptionalArg::Missing => self.code.instructions.clone(), + }; + + let cellvars = match args.co_cellvars { + OptionalArg::Present(cellvars) => cellvars + .into_iter() + .map(|o| o.as_interned_str(vm).unwrap()) + .collect(), + OptionalArg::Missing => self.code.cellvars.clone(), + }; + + let freevars = match args.co_freevars { + OptionalArg::Present(freevars) => freevars + .into_iter() + .map(|o| o.as_interned_str(vm).unwrap()) + .collect(), + OptionalArg::Missing => self.code.freevars.clone(), + }; + + // Validate co_nlocals if provided + if let OptionalArg::Present(nlocals) = args.co_nlocals + && nlocals as usize != varnames.len() + { + return Err(vm.new_value_error(format!( + "co_nlocals ({}) != len(co_varnames) ({})", + nlocals, + varnames.len() + ))); + } + + // Handle linetable and exceptiontable + let linetable = match args.co_linetable { + OptionalArg::Present(linetable) => linetable.as_bytes().to_vec().into_boxed_slice(), + OptionalArg::Missing => self.code.linetable.clone(), + }; + + let exceptiontable = match args.co_exceptiontable { + OptionalArg::Present(exceptiontable) => { + exceptiontable.as_bytes().to_vec().into_boxed_slice() + } + OptionalArg::Missing => self.code.exceptiontable.clone(), + }; + Ok(Self { code: CodeObject { flags: CodeFlags::from_bits_truncate(flags), @@ -417,10 +815,10 @@ impl PyCode { source_path: source_path.as_object().as_interned_str(vm).unwrap(), first_line_number, obj_name: obj_name.as_object().as_interned_str(vm).unwrap(), - qualname: self.code.qualname, + qualname: qualname.as_object().as_interned_str(vm).unwrap(), - max_stackdepth: self.code.max_stackdepth, - instructions: self.code.instructions.clone(), + max_stackdepth, + instructions, locations: self.code.locations.clone(), constants: constants.into_iter().map(Literal).collect(), names: names @@ -431,9 +829,11 @@ impl PyCode { .into_iter() .map(|o| o.as_interned_str(vm).unwrap()) .collect(), - cellvars: self.code.cellvars.clone(), - freevars: self.code.freevars.clone(), + cellvars, + freevars, cell2arg: self.code.cell2arg.clone(), + linetable, + exceptiontable, }, }) } @@ -457,6 +857,69 @@ impl ToPyObject for bytecode::CodeObject { } } +// Helper struct for reading linetable +struct LineTableReader<'a> { + data: &'a [u8], + pos: usize, +} + +impl<'a> LineTableReader<'a> { + fn new(data: &'a [u8]) -> Self { + Self { data, pos: 0 } + } + + fn read_byte(&mut self) -> Option { + if self.pos < self.data.len() { + let byte = self.data[self.pos]; + self.pos += 1; + Some(byte) + } else { + None + } + } + + fn peek_byte(&self) -> Option { + if self.pos < self.data.len() { + Some(self.data[self.pos]) + } else { + None + } + } + + fn read_varint(&mut self) -> u32 { + if let Some(first) = self.read_byte() { + let mut val = (first & 0x3f) as u32; + let mut shift = 0; + let mut byte = first; + while (byte & 0x40) != 0 { + if let Some(next) = self.read_byte() { + shift += 6; + val |= ((next & 0x3f) as u32) << shift; + byte = next; + } else { + break; + } + } + val + } else { + 0 + } + } + + fn read_signed_varint(&mut self) -> i32 { + let uval = self.read_varint(); + if uval & 1 != 0 { + -((uval >> 1) as i32) + } else { + (uval >> 1) as i32 + } + } + + fn at_end(&self) -> bool { + self.pos >= self.data.len() + } +} + pub fn init(ctx: &Context) { PyCode::extend_class(ctx, ctx.types.code_type); }