diff --git a/.gitattributes b/.gitattributes index aa993ad110..0c4ae3d850 100644 --- a/.gitattributes +++ b/.gitattributes @@ -4,3 +4,4 @@ Cargo.lock linguist-generated -merge vm/src/stdlib/ast/gen.rs linguist-generated -merge Lib/*.py text working-tree-encoding=UTF-8 eol=LF **/*.rs text working-tree-encoding=UTF-8 eol=LF +*.pck binary diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ee8d274c18..23f47b574c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -15,14 +15,19 @@ concurrency: cancel-in-progress: true env: - CARGO_ARGS: --no-default-features --features stdlib,zlib,importlib,encodings,sqlite,ssl + CARGO_ARGS: --no-default-features --features stdlib,importlib,encodings,sqlite,ssl # Skip additional tests on Windows. They are checked on Linux and MacOS. + # test_glob: many failing tests + # test_io: many failing tests + # test_os: many failing tests + # test_pathlib: support.rmtree() failing + # test_posixpath: OSError: (22, 'The filename, directory name, or volume label syntax is incorrect. (os error 123)') + # test_venv: couple of failing tests WINDOWS_SKIPS: >- - test_datetime test_glob - test_importlib test_io test_os + test_rlcompleter test_pathlib test_posixpath test_venv @@ -34,7 +39,7 @@ env: # PLATFORM_INDEPENDENT_TESTS are tests that do not depend on the underlying OS. They are currently # only run on Linux to speed up the CI. PLATFORM_INDEPENDENT_TESTS: >- - test_argparse + test__colorize test_array test_asyncgen test_binop @@ -59,7 +64,6 @@ env: test_dis test_enumerate test_exception_variations - test_exceptions test_float test_format test_fractions @@ -100,12 +104,11 @@ env: test_tuple test_types test_unary - test_unicode test_unpack test_weakref test_yield_from # Python version targeted by the CI. 
- PYTHON_VERSION: "3.12.3" + PYTHON_VERSION: "3.13.1" jobs: rust_tests: @@ -378,7 +381,8 @@ jobs: with: { wabt-version: "1.0.30" } - name: check wasm32-unknown without js run: | - cargo build --release --manifest-path wasm/wasm-unknown-test/Cargo.toml --target wasm32-unknown-unknown --verbose + cd wasm/wasm-unknown-test + cargo build --release --verbose if wasm-objdump -xj Import target/wasm32-unknown-unknown/release/wasm_unknown_test.wasm; then echo "ERROR: wasm32-unknown module expects imports from the host environment" >2 fi diff --git a/.github/workflows/cron-ci.yaml b/.github/workflows/cron-ci.yaml index 0880e2d249..4b8e701ced 100644 --- a/.github/workflows/cron-ci.yaml +++ b/.github/workflows/cron-ci.yaml @@ -6,8 +6,8 @@ on: name: Periodic checks/tasks env: - CARGO_ARGS: --no-default-features --features stdlib,zlib,importlib,encodings,ssl,jit - PYTHON_VERSION: "3.12.0" + CARGO_ARGS: --no-default-features --features stdlib,importlib,encodings,ssl,jit + PYTHON_VERSION: "3.13.1" jobs: # codecov collects code coverage data from the rust tests, python snippets and python test suite. @@ -24,7 +24,7 @@ jobs: python-version: ${{ env.PYTHON_VERSION }} - run: sudo apt-get update && sudo apt-get -y install lcov - name: Run cargo-llvm-cov with Rust tests. - run: cargo llvm-cov --no-report --workspace --exclude rustpython_wasm --verbose --no-default-features --features stdlib,zlib,importlib,encodings,ssl,jit + run: cargo llvm-cov --no-report --workspace --exclude rustpython_wasm --verbose --no-default-features --features stdlib,importlib,encodings,ssl,jit - name: Run cargo-llvm-cov with Python snippets. 
run: python scripts/cargo-llvm-cov.py continue-on-error: true @@ -97,6 +97,9 @@ jobs: cd website [ -f ./_data/whats_left.temp ] && cp ./_data/whats_left.temp ./_data/whats_left_lastrun.temp cp ../whats_left.temp ./_data/whats_left.temp + rm _data/whats_left/modules.csv + cat _data/whats_left.temp | grep "(entire module)" | cut -d ' ' -f 1 | sort >> ../_data/whats_left/modules.csv + git add -A if git -c user.name="Github Actions" -c user.email="actions@github.com" commit -m "Update what is left results" --author="$GITHUB_ACTOR"; then git push diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5554ed0a8b..99035d04fc 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -16,7 +16,7 @@ permissions: contents: write env: - CARGO_ARGS: --no-default-features --features stdlib,zlib,importlib,encodings,sqlite,ssl + CARGO_ARGS: --no-default-features --features stdlib,importlib,encodings,sqlite,ssl jobs: build: diff --git a/Cargo.lock b/Cargo.lock index 5d5db7f5ca..67e0813b7b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,12 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. 
-version = 3 - -[[package]] -name = "Inflector" -version = "0.11.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe438c63458706e03479442743baae6c88256498e6431708f6dfc520a26515d3" +version = 4 [[package]] name = "adler2" @@ -27,10 +21,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", - "getrandom", + "getrandom 0.2.15", "once_cell", "version_check", - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -68,9 +62,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.89" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" +checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04" [[package]] name = "approx" @@ -115,9 +109,9 @@ dependencies = [ [[package]] name = "autocfg" -version = "1.3.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "base64" @@ -133,9 +127,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.6.0" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" [[package]] name = "blake2" @@ -157,9 +151,9 @@ dependencies = [ [[package]] name = "bstr" -version = "1.10.0" +version = "1.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40723b8fb387abc38f4f4a37c09073622e41dd12327033091ef8950659e6dc0c" +checksum = 
"531a9155a481e2ee699d4f98f43c0ca4ff8ee1bfd55c31e9e98fb29d2b176fe0" dependencies = [ "memchr", "regex-automata", @@ -168,15 +162,15 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.16.0" +version = "3.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" [[package]] name = "bytemuck" -version = "1.18.0" +version = "1.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94bbb0ad554ad961ddc5da507a12a29b14e4ae5bda06b19f575a3e6079d2e2ae" +checksum = "ef657dfab802224e671f5818e9a4935f9b1957ed18e58292690cc39e7a4092a3" [[package]] name = "byteorder" @@ -196,9 +190,9 @@ dependencies = [ [[package]] name = "bzip2-sys" -version = "0.1.11+1.0.8" +version = "0.1.12+1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" +checksum = "72ebc2f1a417f01e1da30ef264ee86ae31d2dcd2d603ea283d3c244a883ca2a9" dependencies = [ "cc", "libc", @@ -222,9 +216,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.1.21" +version = "1.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07b1695e2c7e8fc85310cde85aeaab7e3097f593c91d209d3f9df76c928100f0" +checksum = "0c3d1b2e905a3a7b00a6141adb0e4c0bb941d11caf55349d863942a1cc44e3c9" dependencies = [ "shlex", ] @@ -235,12 +229,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" -[[package]] -name = "cfg_aliases" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" - [[package]] name = "cfg_aliases" version = "0.2.1" @@ -249,9 
+237,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chrono" -version = "0.4.38" +version = "0.4.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825" dependencies = [ "android-tzdata", "iana-time-zone", @@ -272,7 +260,7 @@ dependencies = [ "bitflags 1.3.2", "strsim", "textwrap 0.11.0", - "unicode-width", + "unicode-width 0.1.14", "vec_map", ] @@ -287,14 +275,14 @@ dependencies = [ [[package]] name = "console" -version = "0.15.8" +version = "0.15.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" +checksum = "ea3c6ecd8059b57859df5c69830340ed3c41d30e3da0c1cbed90a96ac853041b" dependencies = [ "encode_unicode", - "lazy_static 1.5.0", "libc", - "windows-sys 0.52.0", + "once_cell", + "windows-sys 0.59.0", ] [[package]] @@ -307,12 +295,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "convert_case" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" - [[package]] name = "core-foundation" version = "0.9.4" @@ -331,9 +313,9 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "cpufeatures" -version = "0.2.14" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" dependencies = [ "libc", ] @@ -501,9 +483,9 @@ dependencies = [ [[package]] name = "crossbeam-deque" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" dependencies = [ "crossbeam-epoch", "crossbeam-utils", @@ -520,15 +502,15 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.20" +version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crunchy" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" [[package]] name = "crypto-common" @@ -542,9 +524,9 @@ dependencies = [ [[package]] name = "csv" -version = "1.3.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf" dependencies = [ "csv-core", "itoa", @@ -554,24 +536,32 @@ dependencies = [ [[package]] name = "csv-core" -version = "0.1.11" +version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +checksum = "7d02f3b0da4c6504f86e9cd789d8dbafab48c2321be74e9987593de5a894d93d" dependencies = [ "memchr", ] [[package]] name = "derive_more" -version = "0.99.18" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f33878137e4dafd7fa914ad4e259e18a4e8e532b9617a2d0150262bf53abfce" +checksum = "4a9b99b9cbbe49445b21764dc0625032a89b145a2642e67603e1c936f5458d05" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "1.0.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" dependencies = [ - "convert_case", "proc-macro2", "quote", - "rustc_version", - "syn 2.0.77", + "syn 2.0.98", + "unicode-xid", ] [[package]] @@ -620,9 +610,9 @@ dependencies = [ [[package]] name = "dyn-clone" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" +checksum = "feeef44e73baff3a26d371801df019877a9866a8c493d315ab00177843314f35" [[package]] name = "either" @@ -632,9 +622,9 @@ checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] name = "encode_unicode" -version = "0.3.6" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" [[package]] name = "endian-type" @@ -655,18 +645,18 @@ dependencies = [ [[package]] name = "equivalent" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "errno" -version = "0.3.9" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -730,12 +720,12 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.33" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"324a1be68054ef05ad64b861cc9eaf1d623d2d8cb25b4bf2cb9cdd902b4bf253" +checksum = "11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38dc" dependencies = [ "crc32fast", - "libz-sys", + "libz-rs-sys", "miniz_oxide", ] @@ -781,12 +771,12 @@ dependencies = [ [[package]] name = "gethostname" -version = "0.2.3" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1ebd34e35c46e00bb73e81363248d627782724609fe1b6396f553f68fe3862e" +checksum = "4fd4b8790c0792e3b11895efdf5f289ebe8b59107a6624f1cce68f24ff8c7035" dependencies = [ - "libc", - "winapi", + "rustix", + "windows-targets 0.52.6", ] [[package]] @@ -795,7 +785,7 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "14dbbfd5c71d70241ecf9e6f13737f7b5ce823821063188d7e46c41d371eebd5" dependencies = [ - "unicode-width", + "unicode-width 0.1.14", ] [[package]] @@ -807,15 +797,29 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "wasm-bindgen", ] [[package]] -name = "glob" +name = "getrandom" version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "wasi 0.13.3+wasi-0.2.2", + "wasm-bindgen", + "windows-targets 0.52.6", +] + +[[package]] +name = "glob" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "half" @@ -842,6 +846,12 @@ dependencies = [ "ahash", ] +[[package]] +name = "hashbrown" +version = "0.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" + [[package]] name = "heck" version = "0.5.0" @@ -877,11 +887,11 @@ 
checksum = "dfa686283ad6dd069f105e5ab091b04c62850d3e4cf5d67debad1933f55023df" [[package]] name = "home" -version = "0.5.9" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -909,12 +919,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.5.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5" +checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.15.2", ] [[package]] @@ -925,26 +935,27 @@ checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" [[package]] name = "insta" -version = "1.40.0" +version = "1.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6593a41c7a73841868772495db7dc1e8ecab43bb5c0b6da2059246c4b506ab60" +checksum = "71c1b125e30d93896b365e156c33dadfffab45ee8400afcbba4752f59de08a86" dependencies = [ "console", - "lazy_static 1.5.0", "linked-hash-map", + "once_cell", + "pin-project", "similar", ] [[package]] name = "is-macro" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2069faacbe981460232f880d26bf3c7634e322d49053aa48c27e3ae642f728f1" +checksum = "1d57a3e447e24c22647738e4607f1df1e0ec6f72e16182c4cd199f647cdfb0e4" dependencies = [ - "Inflector", + "heck", "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.98", ] [[package]] @@ -965,18 +976,28 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + [[package]] name = "itoa" -version = "1.0.11" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" [[package]] name = "js-sys" -version = "0.3.70" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" +checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" dependencies = [ + "once_cell", "wasm-bindgen", ] @@ -1005,6 +1026,12 @@ version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "507460a910eb7b32ee961886ff48539633b788a36b65692b95f225b844c82553" +[[package]] +name = "lambert_w" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45bf98425154bfe790a47b72ac452914f6df9ebfb202bc59e089e29db00258cf" + [[package]] name = "lazy_static" version = "0.2.11" @@ -1049,9 +1076,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.159" +version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] name = "libffi" @@ -1072,11 +1099,21 @@ dependencies = [ "cc", ] +[[package]] +name = "libloading" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" +dependencies = [ + "cfg-if", + "windows-targets 0.52.6", +] + [[package]] name = "libm" -version = "0.2.8" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" +checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" [[package]] name = "libredox" @@ -1084,7 +1121,7 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "libc", ] @@ -1100,14 +1137,12 @@ dependencies = [ ] [[package]] -name = "libz-sys" -version = "1.1.20" +name = "libz-rs-sys" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2d16453e800a8cf6dd2fc3eb4bc99b786a9b90c663b8559a5b1a041bf89e472" +checksum = "902bc563b5d65ad9bba616b490842ef0651066a1a1dc3ce1087113ffcb873c8d" dependencies = [ - "cc", - "pkg-config", - "vcpkg", + "zlib-rs", ] [[package]] @@ -1118,9 +1153,9 @@ checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" [[package]] name = "linux-raw-sys" -version = "0.4.14" +version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" [[package]] name = "lock_api" @@ -1134,9 +1169,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.22" +version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" [[package]] name = "lz4_flex" @@ -1149,11 +1184,11 @@ dependencies = [ [[package]] name = "mac_address" -version = "1.1.7" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8836fae9d0d4be2c8b4efcdd79e828a2faa058a90d005abf42f91cac5493a08e" +checksum = "c0aeb26bf5e836cc1c341c8106051b573f1766dfa05aa87f0b98be5e51b02303" dependencies = [ 
- "nix 0.28.0", + "nix", "winapi", ] @@ -1168,9 +1203,9 @@ dependencies = [ [[package]] name = "malachite" -version = "0.4.16" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5616515d632967cd329b6f6db96be9a03ea0b3a49cdbc45b0016803dad8a77b7" +checksum = "2fbdf9cb251732db30a7200ebb6ae5d22fe8e11397364416617d2c2cf0c51cb5" dependencies = [ "malachite-base", "malachite-nz", @@ -1179,11 +1214,11 @@ dependencies = [ [[package]] name = "malachite-base" -version = "0.4.16" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46059721011b0458b7bd6d9179be5d0b60294281c23320c207adceaecc54d13b" +checksum = "5ea0ed76adf7defc1a92240b5c36d5368cfe9251640dcce5bd2d0b7c1fd87aeb" dependencies = [ - "hashbrown", + "hashbrown 0.14.5", "itertools 0.11.0", "libm", "ryu", @@ -1191,9 +1226,9 @@ dependencies = [ [[package]] name = "malachite-bigint" -version = "0.2.0" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17703a19c80bbdd0b7919f0f104f3b0597f7de4fc4e90a477c15366a5ba03faa" +checksum = "d149aaa2965d70381709d9df4c7ee1fc0de1c614a4efc2ee356f5e43d68749f8" dependencies = [ "derive_more", "malachite", @@ -1204,9 +1239,9 @@ dependencies = [ [[package]] name = "malachite-nz" -version = "0.4.16" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1503b27e825cabd1c3d0ff1e95a39fb2ec9eab6fd3da6cfa41aec7091d273e78" +checksum = "34a79feebb2bc9aa7762047c8e5495269a367da6b5a90a99882a0aeeac1841f7" dependencies = [ "itertools 0.11.0", "libm", @@ -1215,9 +1250,9 @@ dependencies = [ [[package]] name = "malachite-q" -version = "0.4.16" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a475503a70a3679dbe3b9b230a23622516742528ba614a7b2490f180ea9cb514" +checksum = "50f235d5747b1256b47620f5640c2a17a88c7569eebdf27cd9cb130e1a619191" dependencies = [ "itertools 0.11.0", 
"malachite-base", @@ -1272,20 +1307,20 @@ dependencies = [ [[package]] name = "miniz_oxide" -version = "0.8.0" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +checksum = "b3b1c9bd4fe1f0f8b387f6eb9eb3b4a1aa26185e5750efb9140301703f62cd1b" dependencies = [ "adler2", ] [[package]] name = "mt19937" -version = "2.0.1" +version = "3.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12ca7f22ed370d5991a9caec16a83187e865bc8a532f889670337d5a5689e3a1" +checksum = "df7151a832e54d2d6b2c827a20e5bcdd80359281cd2c354e725d4b82e7c471de" dependencies = [ - "rand_core", + "rand_core 0.9.1", ] [[package]] @@ -1297,28 +1332,15 @@ dependencies = [ "smallvec", ] -[[package]] -name = "nix" -version = "0.28.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4" -dependencies = [ - "bitflags 2.6.0", - "cfg-if", - "cfg_aliases 0.1.1", - "libc", - "memoffset", -] - [[package]] name = "nix" version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "cfg-if", - "cfg_aliases 0.2.1", + "cfg_aliases", "libc", "memoffset", ] @@ -1377,14 +1399,14 @@ checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.98", ] [[package]] name = "once_cell" -version = "1.19.0" +version = "1.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" [[package]] name = "oorandom" @@ -1394,11 +1416,11 @@ checksum = 
"b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" [[package]] name = "openssl" -version = "0.10.66" +version = "0.10.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9529f4786b70a3e8c61e11179af17ab6188ad8d0ded78c5529441ed39d4bd9c1" +checksum = "5e14130c6a98cd258fdcb0fb6d744152343ff729cbfcb28c656a9d12b999fbcd" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "cfg-if", "foreign-types", "libc", @@ -1415,29 +1437,29 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.98", ] [[package]] name = "openssl-probe" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" [[package]] name = "openssl-src" -version = "300.3.2+3.3.2" +version = "300.4.2+3.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a211a18d945ef7e648cc6e0058f4c548ee46aab922ea203e0d30e966ea23647b" +checksum = "168ce4e058f975fe43e89d9ccf78ca668601887ae736090aacc23ae353c298e2" dependencies = [ "cc", ] [[package]] name = "openssl-sys" -version = "0.9.103" +version = "0.9.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6" +checksum = "8bb61ea9811cc39e3c2069f40b8b8e2e70d8569b361f879786cc7ed48b777cdd" dependencies = [ "cc", "libc", @@ -1454,9 +1476,9 @@ checksum = "978aa494585d3ca4ad74929863093e87cac9790d81fe7aba2b3dc2890643a0fc" [[package]] name = "page_size" -version = "0.4.2" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eebde548fbbf1ea81a99b128872779c437752fb99f217c45245e1a61dcd9edcd" +checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da" dependencies = [ 
"libc", "winapi", @@ -1480,7 +1502,7 @@ checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" dependencies = [ "cfg-if", "libc", - "redox_syscall 0.5.5", + "redox_syscall 0.5.8", "smallvec", "windows-targets 0.52.6", ] @@ -1493,18 +1515,18 @@ checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" [[package]] name = "phf" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" dependencies = [ "phf_shared", ] [[package]] name = "phf_codegen" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a" +checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a" dependencies = [ "phf_generator", "phf_shared", @@ -1512,21 +1534,41 @@ dependencies = [ [[package]] name = "phf_generator" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" dependencies = [ "phf_shared", - "rand", + "rand 0.8.5", ] [[package]] name = "phf_shared" -version = "0.11.2" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" +dependencies = [ + "siphasher 1.0.1", +] + +[[package]] +name = "pin-project" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +checksum = "dfe2e71e1471fe07709406bf725f710b02927c9c54b2b5b2ec0e8087d97c327d" dependencies = [ - "siphasher", + 
"pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6e859e6e5bd50440ab63c47e3ebabc90f26251f7c73c3d3e837b74a1cc3fa67" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.98", ] [[package]] @@ -1571,14 +1613,14 @@ checksum = "52a40bc70c2c58040d2d8b167ba9a5ff59fc9dab7ad44771cfde3dcfde7a09c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.98", ] [[package]] name = "portable-atomic" -version = "1.8.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d30538d42559de6b034bc76fd6dd4c38961b1ee5c6c56e3808c50128fdbc22ce" +checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" [[package]] name = "ppv-lite86" @@ -1586,29 +1628,33 @@ version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ - "zerocopy", + "zerocopy 0.7.35", ] [[package]] name = "proc-macro2" -version = "1.0.86" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" dependencies = [ "unicode-ident", ] [[package]] name = "puruspe" -version = "0.2.5" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3804877ffeba468c806c2ad9057bbbae92e4b2c410c2f108baaa0042f241fa4c" +checksum = "d76c522e44709f541a403db419a7e34d6fbbc8e6b208589ae29a030cddeefd96" +dependencies = [ + "lambert_w", + "num-complex", +] [[package]] name = "pyo3" -version = "0.22.4" +version = "0.22.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00e89ce2565d6044ca31a3eb79a334c3a79a841120a98f64eea9f579564cb691" +checksum = 
"f402062616ab18202ae8319da13fa4279883a2b8a9d9f83f20dbade813ce1884" dependencies = [ "cfg-if", "indoc", @@ -1624,9 +1670,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.22.4" +version = "0.22.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8afbaf3abd7325e08f35ffb8deb5892046fcb2608b703db6a583a5ba4cea01e" +checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38" dependencies = [ "once_cell", "target-lexicon", @@ -1634,9 +1680,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.22.4" +version = "0.22.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec15a5ba277339d04763f4c23d85987a5b08cbb494860be141e6a10a8eb88022" +checksum = "9ab5bcf04a2cdcbb50c7d6105de943f543f9ed92af55818fd17b660390fc8636" dependencies = [ "libc", "pyo3-build-config", @@ -1644,34 +1690,34 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.22.4" +version = "0.22.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15e0f01b5364bcfbb686a52fc4181d412b708a68ed20c330db9fc8d2c2bf5a43" +checksum = "0fd24d897903a9e6d80b968368a34e1525aeb719d568dba8b3d4bfa5dc67d453" dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.77", + "syn 2.0.98", ] [[package]] name = "pyo3-macros-backend" -version = "0.22.4" +version = "0.22.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a09b550200e1e5ed9176976d0060cbc2ea82dc8515da07885e7b8153a85caacb" +checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe" dependencies = [ "heck", "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.77", + "syn 2.0.98", ] [[package]] name = "quote" -version = "1.0.37" +version = "1.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +checksum = 
"0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" dependencies = [ "proc-macro2", ] @@ -1699,8 +1745,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", - "rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.1", + "zerocopy 0.8.18", ] [[package]] @@ -1710,7 +1767,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.1", ] [[package]] @@ -1719,7 +1786,17 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.15", +] + +[[package]] +name = "rand_core" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a88e0da7a2c97baa202165137c158d0a2e824ac465d13d81046727b34cb247d3" +dependencies = [ + "getrandom 0.3.1", + "zerocopy 0.8.18", ] [[package]] @@ -1750,11 +1827,11 @@ checksum = "41cc0f7e4d5d4544e8861606a285bb08d3e70712ccc7d2b84d7c0ccfaf4b05ce" [[package]] name = "redox_syscall" -version = "0.5.5" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"62871f2d65009c0256aed1b9cfeeb8ac272833c404e13d53d400cd0dad7a2ac0" +checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", ] [[package]] @@ -1763,9 +1840,9 @@ version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ - "getrandom", + "getrandom 0.2.15", "libredox", - "thiserror", + "thiserror 1.0.69", ] [[package]] @@ -1782,9 +1859,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.6" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", @@ -1794,9 +1871,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.7" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", @@ -1805,9 +1882,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "region" @@ -1839,7 +1916,7 @@ dependencies = [ "pmutil", "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.98", ] [[package]] @@ -1859,15 +1936,15 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.37" +version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "errno", "libc", "linux-raw-sys", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -1895,8 +1972,7 @@ dependencies = [ [[package]] name = "rustpython-ast" version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cdaf8ee5c1473b993b398c174641d3aa9da847af36e8d5eb8291930b72f31a5" +source = "git+https://github.com/RustPython/Parser.git?rev=d2f137b372ec08ce4a243564a80f8f9153c45a23#d2f137b372ec08ce4a243564a80f8f9153c45a23" dependencies = [ "is-macro", "malachite-bigint", @@ -1910,14 +1986,15 @@ name = "rustpython-codegen" version = "0.4.0" dependencies = [ "ahash", - "bitflags 2.6.0", + "bitflags 2.8.0", "indexmap", "insta", - "itertools 0.11.0", + "itertools 0.14.0", "log", "num-complex", "num-traits", "rustpython-ast", + "rustpython-common", "rustpython-compiler-core", "rustpython-parser", "rustpython-parser-core", @@ -1928,9 +2005,9 @@ name = "rustpython-common" version = "0.4.0" dependencies = [ "ascii", - "bitflags 2.6.0", + "bitflags 2.8.0", "cfg-if", - "itertools 0.11.0", + "itertools 0.14.0", "libc", "lock_api", "malachite-base", @@ -1941,9 +2018,9 @@ dependencies = [ "once_cell", "parking_lot", "radium", - "rand", + "rand 0.9.0", "rustpython-format", - "siphasher", + "siphasher 0.3.11", "volatile", "widestring", "windows-sys 0.52.0", @@ -1962,8 +2039,8 @@ dependencies = [ name = "rustpython-compiler-core" version = "0.4.0" dependencies = [ - "bitflags 2.6.0", - "itertools 0.11.0", + "bitflags 2.8.0", + "itertools 0.14.0", "lz4_flex", "malachite-bigint", "num-complex", @@ -1977,14 +2054,14 @@ version = "0.4.0" dependencies = [ "rustpython-compiler", "rustpython-derive-impl", - "syn 1.0.109", + "syn 2.0.98", ] [[package]] name = "rustpython-derive-impl" version = "0.4.0" dependencies = [ - "itertools 0.11.0", 
+ "itertools 0.14.0", "maplit", "once_cell", "proc-macro2", @@ -1992,9 +2069,9 @@ dependencies = [ "rustpython-compiler-core", "rustpython-doc", "rustpython-parser-core", - "syn 1.0.109", + "syn 2.0.98", "syn-ext", - "textwrap 0.15.2", + "textwrap 0.16.1", ] [[package]] @@ -2008,10 +2085,9 @@ dependencies = [ [[package]] name = "rustpython-format" version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0389039b132ad8e350552d771270ccd03186985696764bcee2239694e7839942" +source = "git+https://github.com/RustPython/Parser.git?rev=d2f137b372ec08ce4a243564a80f8f9153c45a23#d2f137b372ec08ce4a243564a80f8f9153c45a23" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "itertools 0.11.0", "malachite-bigint", "num-traits", @@ -2030,14 +2106,13 @@ dependencies = [ "num-traits", "rustpython-compiler-core", "rustpython-derive", - "thiserror", + "thiserror 2.0.11", ] [[package]] name = "rustpython-literal" version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8304be3cae00232a1721a911033e55877ca3810215f66798e964a2d8d22281d" +source = "git+https://github.com/RustPython/Parser.git?rev=d2f137b372ec08ce4a243564a80f8f9153c45a23#d2f137b372ec08ce4a243564a80f8f9153c45a23" dependencies = [ "hexf-parse", "is-macro", @@ -2049,8 +2124,7 @@ dependencies = [ [[package]] name = "rustpython-parser" version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "868f724daac0caf9bd36d38caf45819905193a901e8f1c983345a68e18fb2abb" +source = "git+https://github.com/RustPython/Parser.git?rev=d2f137b372ec08ce4a243564a80f8f9153c45a23#d2f137b372ec08ce4a243564a80f8f9153c45a23" dependencies = [ "anyhow", "is-macro", @@ -2073,8 +2147,7 @@ dependencies = [ [[package]] name = "rustpython-parser-core" version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4b6c12fa273825edc7bccd9a734f0ad5ba4b8a2f4da5ff7efe946f066d0f4ad" +source = 
"git+https://github.com/RustPython/Parser.git?rev=d2f137b372ec08ce4a243564a80f8f9153c45a23#d2f137b372ec08ce4a243564a80f8f9153c45a23" dependencies = [ "is-macro", "memchr", @@ -2084,8 +2157,7 @@ dependencies = [ [[package]] name = "rustpython-parser-vendored" version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04fcea49a4630a3a5d940f4d514dc4f575ed63c14c3e3ed07146634aed7f67a6" +source = "git+https://github.com/RustPython/Parser.git?rev=d2f137b372ec08ce4a243564a80f8f9153c45a23#d2f137b372ec08ce4a243564a80f8f9153c45a23" dependencies = [ "memchr", "once_cell", @@ -2104,7 +2176,8 @@ dependencies = [ name = "rustpython-sre_engine" version = "0.4.0" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", + "criterion", "num_enum", "optional", ] @@ -2131,18 +2204,18 @@ dependencies = [ "gethostname", "hex", "indexmap", - "itertools 0.11.0", + "itertools 0.14.0", "junction", "libc", "libsqlite3-sys", - "libz-sys", + "libz-rs-sys", "mac_address", "malachite-bigint", "md-5", "memchr", "memmap2", "mt19937", - "nix 0.29.0", + "nix", "num-complex", "num-integer", "num-traits", @@ -2155,8 +2228,7 @@ dependencies = [ "parking_lot", "paste", "puruspe", - "rand", - "rand_core", + "rand 0.9.0", "rustix", "rustpython-common", "rustpython-derive", @@ -2191,29 +2263,31 @@ version = "0.4.0" dependencies = [ "ahash", "ascii", - "bitflags 2.6.0", + "bitflags 2.8.0", "bstr", "caseless", "cfg-if", "chrono", "crossbeam-utils", + "errno", "exitcode", "flame", "flamer", - "getrandom", + "getrandom 0.3.1", "glob", "half 2.4.1", "hex", "indexmap", "is-macro", - "itertools 0.11.0", + "itertools 0.14.0", "junction", "libc", + "libloading", "log", "malachite-bigint", "memchr", "memoffset", - "nix 0.29.0", + "nix", "num-complex", "num-integer", "num-traits", @@ -2223,7 +2297,7 @@ dependencies = [ "optional", "parking_lot", "paste", - "rand", + "rand 0.9.0", "result-like", "rustc_version", "rustix", @@ -2245,7 +2319,7 @@ dependencies = [ "static_assertions", 
"strum", "strum_macros", - "thiserror", + "thiserror 2.0.11", "thread_local", "timsort", "uname", @@ -2267,6 +2341,7 @@ name = "rustpython_wasm" version = "0.4.0" dependencies = [ "console_error_panic_hook", + "getrandom 0.2.15", "js-sys", "rustpython-common", "rustpython-parser", @@ -2282,17 +2357,17 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.17" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" +checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" [[package]] name = "rustyline" -version = "14.0.0" +version = "15.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7803e8936da37efd9b6d4478277f4b2b9bb5cdb37a113e8d63222e58da647e63" +checksum = "2ee1e066dc922e513bda599c6ccb5f3bb2b0ea5870a579448f2622993f0a9a2f" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.8.0", "cfg-if", "clipboard-win", "fd-lock", @@ -2300,19 +2375,19 @@ dependencies = [ "libc", "log", "memchr", - "nix 0.28.0", + "nix", "radix_trie", "unicode-segmentation", - "unicode-width", + "unicode-width 0.2.0", "utf8parse", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] name = "ryu" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" [[package]] name = "same-file" @@ -2325,9 +2400,9 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.24" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9aaafd5a2b6e3d657ff009d82fbd630b6bd54dd4eb06f21693925cdf80f9b8b" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" dependencies = [ "windows-sys 0.59.0", ] @@ -2340,15 +2415,15 @@ checksum = 
"94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "semver" -version = "1.0.23" +version = "1.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" +checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03" [[package]] name = "serde" -version = "1.0.210" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" dependencies = [ "serde_derive", ] @@ -2377,20 +2452,20 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.98", ] [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.138" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949" dependencies = [ "itoa", "memchr", @@ -2438,9 +2513,9 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "similar" -version = "2.6.0" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1de1d4f81173b03af4c0cbed3c898f6bff5b870e4a7f5d6f4057d62a7a4b686e" +checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" [[package]] name = "siphasher" @@ -2448,6 +2523,12 @@ version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" +[[package]] +name = "siphasher" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" + [[package]] name = "slice-group-by" version = "0.3.1" @@ -2456,15 +2537,15 @@ checksum = "826167069c09b99d56f31e9ae5c99049e932a98c9dc2dac47645b08dbbf76ba7" [[package]] name = "smallvec" -version = "1.13.2" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd" [[package]] name = "socket2" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" +checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" dependencies = [ "libc", "windows-sys 0.52.0", @@ -2484,21 +2565,21 @@ checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" [[package]] name = "strum" -version = "0.26.3" +version = "0.27.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" +checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32" [[package]] name = "strum_macros" -version = "0.26.4" +version = "0.27.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8" dependencies = [ "heck", "proc-macro2", "quote", "rustversion", - "syn 2.0.77", + "syn 2.0.98", ] [[package]] @@ -2520,9 +2601,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.77" +version = "2.0.98" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" +checksum = "36147f1a48ae0ec2b5b3bc5b537d267457555a10dc06f3dbc8cb11ba3006d3b1" dependencies = [ "proc-macro2", "quote", @@ -2531,11 +2612,13 @@ dependencies = [ [[package]] name = "syn-ext" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b86cb2b68c5b3c078cac02588bc23f3c04bb828c5d3aedd17980876ec6a7be6" +checksum = "b126de4ef6c2a628a68609dd00733766c3b015894698a438ebdf374933fc31d1" dependencies = [ - "syn 1.0.109", + "proc-macro2", + "quote", + "syn 2.0.98", ] [[package]] @@ -2589,33 +2672,53 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" dependencies = [ - "unicode-width", + "unicode-width 0.1.14", ] [[package]] name = "textwrap" -version = "0.15.2" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7b3e525a49ec206798b40326a44121291b530c963cfb01018f63e135bac543d" +checksum = "23d434d3f8967a09480fb04132ebe0a3e088c173e6d0ee7897abbdf4eab0f8b9" [[package]] name = "thiserror" -version = "1.0.64" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" +dependencies = [ + "thiserror-impl 2.0.11", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" +checksum = 
"4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ - "thiserror-impl", + "proc-macro2", + "quote", + "syn 2.0.98", ] [[package]] name = "thiserror-impl" -version = "1.0.64" +version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" +checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.98", ] [[package]] @@ -2666,9 +2769,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8" dependencies = [ "tinyvec_macros", ] @@ -2691,9 +2794,9 @@ dependencies = [ [[package]] name = "typenum" -version = "1.17.0" +version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" [[package]] name = "ucd" @@ -2834,9 +2937,9 @@ checksum = "623f59e6af2a98bdafeb93fa277ac8e1e40440973001ca15cf4ae1541cd16d56" [[package]] name = "unicode-ident" -version = "1.0.13" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" +checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" [[package]] name = "unicode-normalization" @@ -2859,6 +2962,18 @@ version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" +[[package]] +name = "unicode-width" +version = "0.2.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "unicode_names2" version = "1.3.0" @@ -2878,7 +2993,7 @@ dependencies = [ "getopts", "log", "phf_codegen", - "rand", + "rand 0.8.5", ] [[package]] @@ -2895,13 +3010,11 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.10.0" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +checksum = "8c1f41ffb7cf259f1ecc2876861a17e7142e63ead296f671f81f6ae85903e0d6" dependencies = [ "atomic", - "getrandom", - "rand", ] [[package]] @@ -2944,48 +3057,59 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasi" +version = "0.13.3+wasi-0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] + [[package]] name = "wasm-bindgen" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a474f6281d1d70c17ae7aa6a613c87fce69a127e2624002df63dcb39d6cf6396" +checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" dependencies = [ "cfg-if", "once_cell", + "rustversion", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"5f89bb38646b4f81674e8f5c3fb81b562be1fd936d84320f3264486418519c79" +checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" dependencies = [ "bumpalo", "log", "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.98", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.43" +version = "0.4.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" +checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61" dependencies = [ "cfg-if", "js-sys", + "once_cell", "wasm-bindgen", "web-sys", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cc6181fd9a7492eef6fef1f33961e3695e4579b9872a6f7c83aee556666d4fe" +checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2993,28 +3117,31 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2" +checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.98", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "943aab3fdaaa029a6e0271b35ea10b72b943135afe9bffca82384098ad0e06a6" +checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +dependencies = [ + "unicode-ident", +] [[package]] name = "web-sys" -version = "0.3.70" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" +checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" dependencies = [ "js-sys", "wasm-bindgen", @@ -3281,12 +3408,12 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winreg" -version = "0.52.0" +version = "0.55.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a277a57398d4bfa075df44f501a17cfdf8542d224f0d36095a2adc7aee4ef0a5" +checksum = "cb5a765337c50e9ec252c2069be9bf91c7df47afb103b642ba3a53bf8101be97" dependencies = [ "cfg-if", - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] @@ -3295,11 +3422,20 @@ version = "0.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d135d17ab770252ad95e9a872d365cf3090e3be864a34ab46f48555993efc904" +[[package]] +name = "wit-bindgen-rt" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +dependencies = [ + "bitflags 2.8.0", +] + [[package]] name = "xml-rs" -version = "0.8.22" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af4e2e2f7cba5a093896c1e150fbfe177d1883e7448200efb81d40b9d339ef26" +checksum = "c5b940ebc25896e71dd073bad2dbaa2abfe97b0a391415e22ad1326d9c54e3c4" [[package]] name = "zerocopy" @@ -3308,7 +3444,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "byteorder", - "zerocopy-derive", + "zerocopy-derive 0.7.35", +] + +[[package]] +name = "zerocopy" +version = "0.8.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79386d31a42a4996e3336b0919ddb90f81112af416270cff95b5f5af22b839c2" +dependencies = [ + "zerocopy-derive 0.8.18", ] [[package]] @@ -3319,5 +3464,22 @@ checksum = 
"fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn 2.0.98", ] + +[[package]] +name = "zerocopy-derive" +version = "0.8.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76331675d372f91bf8d17e13afbd5fe639200b73d01f0fc748bb059f9cca2db7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.98", +] + +[[package]] +name = "zlib-rs" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b20717f0917c908dc63de2e44e97f1e6b126ca58d0e391cee86d504eb8fbd05" diff --git a/Cargo.toml b/Cargo.toml index af27add260..33fb96e84c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ repository.workspace = true license.workspace = true [features] -default = ["threading", "stdlib", "zlib", "importlib"] +default = ["threading", "stdlib", "importlib"] importlib = ["rustpython-vm/importlib"] encodings = ["rustpython-vm/encodings"] stdlib = ["rustpython-stdlib", "rustpython-pylib", "encodings"] @@ -18,7 +18,6 @@ flame-it = ["rustpython-vm/flame-it", "flame", "flamescope"] freeze-stdlib = ["stdlib", "rustpython-vm/freeze-stdlib", "rustpython-pylib?/freeze-stdlib"] jit = ["rustpython-vm/jit"] threading = ["rustpython-vm/threading", "rustpython-stdlib/threading"] -zlib = ["stdlib", "rustpython-stdlib/zlib"] bz2 = ["stdlib", "rustpython-stdlib/bz2"] sqlite = ["rustpython-stdlib/sqlite"] ssl = ["rustpython-stdlib/ssl"] @@ -47,7 +46,7 @@ libc = { workspace = true } rustyline = { workspace = true } [dev-dependencies] -criterion = { version = "0.3.5", features = ["html_reports"] } +criterion = { workspace = true } pyo3 = { version = "0.22", features = ["auto-initialize"] } [[bench]] @@ -103,8 +102,8 @@ members = [ [workspace.package] version = "0.4.0" authors = ["RustPython Team"] -edition = "2021" -rust-version = "1.83.0" +edition = "2024" +rust-version = "1.85.0" repository = "https://github.com/RustPython/RustPython" license = "MIT" 
@@ -122,16 +121,16 @@ rustpython-stdlib = { path = "stdlib", default-features = false, version = "0.4. rustpython-sre_engine = { path = "vm/sre_engine", version = "0.4.0" } rustpython-doc = { git = "https://github.com/RustPython/__doc__", tag = "0.3.0", version = "0.3.0" } -rustpython-literal = { version = "0.4.0" } -rustpython-parser-core = { version = "0.4.0" } -rustpython-parser = { version = "0.4.0" } -rustpython-ast = { version = "0.4.0" } -rustpython-format= { version = "0.4.0" } -# rustpython-literal = { git = "https://github.com/RustPython/Parser.git", version = "0.4.0", rev = "00d2f1d1a7522ef9c85c10dfa5f0bb7178dee655" } -# rustpython-parser-core = { git = "https://github.com/RustPython/Parser.git", version = "0.4.0", rev = "00d2f1d1a7522ef9c85c10dfa5f0bb7178dee655" } -# rustpython-parser = { git = "https://github.com/RustPython/Parser.git", version = "0.4.0", rev = "00d2f1d1a7522ef9c85c10dfa5f0bb7178dee655" } -# rustpython-ast = { git = "https://github.com/RustPython/Parser.git", version = "0.4.0", rev = "00d2f1d1a7522ef9c85c10dfa5f0bb7178dee655" } -# rustpython-format = { git = "https://github.com/RustPython/Parser.git", version = "0.4.0", rev = "00d2f1d1a7522ef9c85c10dfa5f0bb7178dee655" } +# rustpython-literal = { version = "0.4.0" } +# rustpython-parser-core = { version = "0.4.0" } +# rustpython-parser = { version = "0.4.0" } +# rustpython-ast = { version = "0.4.0" } +# rustpython-format= { version = "0.4.0" } +rustpython-literal = { git = "https://github.com/RustPython/Parser.git", version = "0.4.0", rev = "d2f137b372ec08ce4a243564a80f8f9153c45a23" } +rustpython-parser-core = { git = "https://github.com/RustPython/Parser.git", version = "0.4.0", rev = "d2f137b372ec08ce4a243564a80f8f9153c45a23" } +rustpython-parser = { git = "https://github.com/RustPython/Parser.git", version = "0.4.0", rev = "d2f137b372ec08ce4a243564a80f8f9153c45a23" } +rustpython-ast = { git = "https://github.com/RustPython/Parser.git", version = "0.4.0", rev = 
"d2f137b372ec08ce4a243564a80f8f9153c45a23" } +rustpython-format = { git = "https://github.com/RustPython/Parser.git", version = "0.4.0", rev = "d2f137b372ec08ce4a243564a80f8f9153c45a23" } # rustpython-literal = { path = "../RustPython-parser/literal" } # rustpython-parser-core = { path = "../RustPython-parser/core" } # rustpython-parser = { path = "../RustPython-parser/parser" } @@ -139,55 +138,57 @@ rustpython-format= { version = "0.4.0" } # rustpython-format = { path = "../RustPython-parser/format" } ahash = "0.8.11" -ascii = "1.0" -bitflags = "2.4.1" +ascii = "1.1" +bitflags = "2.4.2" bstr = "1" cfg-if = "1.0" -chrono = "0.4.37" -crossbeam-utils = "0.8.19" +chrono = "0.4.39" +criterion = { version = "0.3.5", features = ["html_reports"] } +crossbeam-utils = "0.8.21" flame = "0.2.2" -getrandom = "0.2.12" +getrandom = "0.3" glob = "0.3" hex = "0.4.3" indexmap = { version = "2.2.6", features = ["std"] } insta = "1.38.0" -itertools = "0.11.0" -is-macro = "0.3.0" -junction = "1.0.0" -libc = "0.2.153" -log = "0.4.16" +itertools = "0.14.0" +is-macro = "0.3.7" +junction = "1.2.0" +libc = "0.2.169" +log = "0.4.25" nix = { version = "0.29", features = ["fs", "user", "process", "term", "time", "signal", "ioctl", "socket", "sched", "zerocopy", "dir", "hostname", "net", "poll"] } -malachite-bigint = "0.2.0" -malachite-q = "0.4.4" -malachite-base = "0.4.4" -memchr = "2.7.2" -num-complex = "0.4.0" -num-integer = "0.1.44" +malachite-bigint = "0.2.3" +malachite-q = "0.4.22" +malachite-base = "0.4.22" +memchr = "2.7.4" +num-complex = "0.4.6" +num-integer = "0.1.46" num-traits = "0.2" num_enum = { version = "0.7", default-features = false } -once_cell = "1.19.0" -parking_lot = "0.12.1" -paste = "1.0.7" -rand = "0.8.5" +once_cell = "1.20.3" +parking_lot = "0.12.3" +paste = "1.0.15" +rand = "0.9" rustix = { version = "0.38", features = ["event"] } -rustyline = "14.0.0" +rustyline = "15.0.0" serde = { version = "1.0.133", default-features = false } -schannel = "0.1.22" +schannel = 
"0.1.27" static_assertions = "1.1" -strum = "0.26" -strum_macros = "0.26" -syn = "1.0.109" -thiserror = "1.0" -thread_local = "1.1.4" -unicode_names2 = "1.2.0" +strum = "0.27" +strum_macros = "0.27" +syn = "2" +thiserror = "2.0" +thread_local = "1.1.8" +unicode_names2 = "1.3.0" widestring = "1.1.0" windows-sys = "0.52.0" -wasm-bindgen = "0.2.92" +wasm-bindgen = "0.2.100" # Lints [workspace.lints.rust] unsafe_code = "allow" +unsafe_op_in_unsafe_fn = "deny" [workspace.lints.clippy] perf = "warn" diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index 7c79a011ba..aa7d99eef3 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -25,7 +25,7 @@ RustPython requires the following: stable version: `rustup update stable` - If you do not have Rust installed, use [rustup](https://rustup.rs/) to do so. -- CPython version 3.12 or higher +- CPython version 3.13 or higher - CPython can be installed by your operating system's package manager, from the [Python website](https://www.python.org/downloads/), or using a third-party distribution, such as diff --git a/LICENSE b/LICENSE index 7213274e0f..e2aa2ed952 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2020 RustPython Team +Copyright (c) 2025 RustPython Team Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/Lib/_colorize.py b/Lib/_colorize.py new file mode 100644 index 0000000000..70acfd4ad0 --- /dev/null +++ b/Lib/_colorize.py @@ -0,0 +1,67 @@ +import io +import os +import sys + +COLORIZE = True + + +class ANSIColors: + BOLD_GREEN = "\x1b[1;32m" + BOLD_MAGENTA = "\x1b[1;35m" + BOLD_RED = "\x1b[1;31m" + GREEN = "\x1b[32m" + GREY = "\x1b[90m" + MAGENTA = "\x1b[35m" + RED = "\x1b[31m" + RESET = "\x1b[0m" + YELLOW = "\x1b[33m" + + +NoColors = ANSIColors() + +for attr in dir(NoColors): + if not attr.startswith("__"): + setattr(NoColors, attr, "") + + +def get_colors(colorize: bool = False, *, 
file=None) -> ANSIColors: + if colorize or can_colorize(file=file): + return ANSIColors() + else: + return NoColors + + +def can_colorize(*, file=None) -> bool: + if file is None: + file = sys.stdout + + if not sys.flags.ignore_environment: + if os.environ.get("PYTHON_COLORS") == "0": + return False + if os.environ.get("PYTHON_COLORS") == "1": + return True + if os.environ.get("NO_COLOR"): + return False + if not COLORIZE: + return False + if os.environ.get("FORCE_COLOR"): + return True + if os.environ.get("TERM") == "dumb": + return False + + if not hasattr(file, "fileno"): + return False + + if sys.platform == "win32": + try: + import nt + + if not nt._supports_virtual_terminal(): + return False + except (ImportError, AttributeError): + return False + + try: + return os.isatty(file.fileno()) + except io.UnsupportedOperation: + return file.isatty() diff --git a/Lib/ctypes/test/test_numbers.py b/Lib/ctypes/test/test_numbers.py index a7696376a5..a5c661b0e9 100644 --- a/Lib/ctypes/test/test_numbers.py +++ b/Lib/ctypes/test/test_numbers.py @@ -147,10 +147,10 @@ def test_alignments(self): # alignment of the type... self.assertEqual((code, alignment(t)), - (code, align)) + (code, align)) # and alignment of an instance self.assertEqual((code, alignment(t())), - (code, align)) + (code, align)) def test_int_from_address(self): from array import array diff --git a/Lib/difflib.py b/Lib/difflib.py index ba0b256969..3425e438c9 100644 --- a/Lib/difflib.py +++ b/Lib/difflib.py @@ -1200,25 +1200,6 @@ def context_diff(a, b, fromfile='', tofile='', strings for 'fromfile', 'tofile', 'fromfiledate', and 'tofiledate'. The modification times are normally expressed in the ISO 8601 format. If not specified, the strings default to blanks. - - Example: - - >>> print(''.join(context_diff('one\ntwo\nthree\nfour\n'.splitlines(True), - ... 'zero\none\ntree\nfour\n'.splitlines(True), 'Original', 'Current')), - ... end="") - *** Original - --- Current - *************** - *** 1,4 **** - one - ! 
two - ! three - four - --- 1,4 ---- - + zero - one - ! tree - four """ _check_types(a, b, fromfile, tofile, fromfiledate, tofiledate, lineterm) diff --git a/Lib/gzip.py b/Lib/gzip.py index 5b20e5ba69..1a3c82ce7e 100644 --- a/Lib/gzip.py +++ b/Lib/gzip.py @@ -15,12 +15,16 @@ FTEXT, FHCRC, FEXTRA, FNAME, FCOMMENT = 1, 2, 4, 8, 16 -READ, WRITE = 1, 2 +READ = 'rb' +WRITE = 'wb' _COMPRESS_LEVEL_FAST = 1 _COMPRESS_LEVEL_TRADEOFF = 6 _COMPRESS_LEVEL_BEST = 9 +READ_BUFFER_SIZE = 128 * 1024 +_WRITE_BUFFER_SIZE = 4 * io.DEFAULT_BUFFER_SIZE + def open(filename, mode="rb", compresslevel=_COMPRESS_LEVEL_BEST, encoding=None, errors=None, newline=None): @@ -118,6 +122,21 @@ class BadGzipFile(OSError): """Exception raised in some cases for invalid gzip files.""" +class _WriteBufferStream(io.RawIOBase): + """Minimal object to pass WriteBuffer flushes into GzipFile""" + def __init__(self, gzip_file): + self.gzip_file = gzip_file + + def write(self, data): + return self.gzip_file._write_raw(data) + + def seekable(self): + return False + + def writable(self): + return True + + class GzipFile(_compression.BaseStream): """The GzipFile class simulates most of the methods of a file object with the exception of the truncate() method. @@ -160,9 +179,10 @@ def __init__(self, filename=None, mode=None, and 9 is slowest and produces the most compression. 0 is no compression at all. The default is 9. - The mtime argument is an optional numeric timestamp to be written - to the last modification time field in the stream when compressing. - If omitted or None, the current time is used. + The optional mtime argument is the timestamp requested by gzip. The time + is in Unix format, i.e., seconds since 00:00:00 UTC, January 1, 1970. + If mtime is omitted or None, the current time is used. Use mtime = 0 + to generate a compressed stream that does not depend on creation time. 
""" @@ -182,6 +202,7 @@ def __init__(self, filename=None, mode=None, if mode is None: mode = getattr(fileobj, 'mode', 'rb') + if mode.startswith('r'): self.mode = READ raw = _GzipReader(fileobj) @@ -204,6 +225,9 @@ def __init__(self, filename=None, mode=None, zlib.DEF_MEM_LEVEL, 0) self._write_mtime = mtime + self._buffer_size = _WRITE_BUFFER_SIZE + self._buffer = io.BufferedWriter(_WriteBufferStream(self), + buffer_size=self._buffer_size) else: raise ValueError("Invalid mode: {!r}".format(mode)) @@ -212,14 +236,6 @@ def __init__(self, filename=None, mode=None, if self.mode == WRITE: self._write_gzip_header(compresslevel) - @property - def filename(self): - import warnings - warnings.warn("use the name attribute", DeprecationWarning, 2) - if self.mode == WRITE and self.name[-3:] != ".gz": - return self.name + ".gz" - return self.name - @property def mtime(self): """Last modification time read from stream, or None""" @@ -237,6 +253,11 @@ def _init_write(self, filename): self.bufsize = 0 self.offset = 0 # Current file offset for seek(), tell(), etc + def tell(self): + self._check_not_closed() + self._buffer.flush() + return super().tell() + def _write_gzip_header(self, compresslevel): self.fileobj.write(b'\037\213') # magic header self.fileobj.write(b'\010') # compression method @@ -278,6 +299,10 @@ def write(self,data): if self.fileobj is None: raise ValueError("write() on closed GzipFile object") + return self._buffer.write(data) + + def _write_raw(self, data): + # Called by our self._buffer underlying WriteBufferStream. 
if isinstance(data, (bytes, bytearray)): length = len(data) else: @@ -326,11 +351,11 @@ def closed(self): def close(self): fileobj = self.fileobj - if fileobj is None: + if fileobj is None or self._buffer.closed: return - self.fileobj = None try: if self.mode == WRITE: + self._buffer.flush() fileobj.write(self.compress.flush()) write32u(fileobj, self.crc) # self.size may exceed 2 GiB, or even 4 GiB @@ -338,6 +363,7 @@ def close(self): elif self.mode == READ: self._buffer.close() finally: + self.fileobj = None myfileobj = self.myfileobj if myfileobj: self.myfileobj = None @@ -346,6 +372,7 @@ def close(self): def flush(self,zlib_mode=zlib.Z_SYNC_FLUSH): self._check_not_closed() if self.mode == WRITE: + self._buffer.flush() # Ensure the compressor's buffer is flushed self.fileobj.write(self.compress.flush(zlib_mode)) self.fileobj.flush() @@ -376,6 +403,9 @@ def seekable(self): def seek(self, offset, whence=io.SEEK_SET): if self.mode == WRITE: + self._check_not_closed() + # Flush buffer to ensure validity of self.offset + self._buffer.flush() if whence != io.SEEK_SET: if whence == io.SEEK_CUR: offset = self.offset + offset @@ -384,10 +414,10 @@ def seek(self, offset, whence=io.SEEK_SET): if offset < self.offset: raise OSError('Negative seek in write mode') count = offset - self.offset - chunk = b'\0' * 1024 - for i in range(count // 1024): + chunk = b'\0' * self._buffer_size + for i in range(count // self._buffer_size): self.write(chunk) - self.write(b'\0' * (count % 1024)) + self.write(b'\0' * (count % self._buffer_size)) elif self.mode == READ: self._check_not_closed() return self._buffer.seek(offset, whence) @@ -454,7 +484,7 @@ def _read_gzip_header(fp): class _GzipReader(_compression.DecompressReader): def __init__(self, fp): - super().__init__(_PaddedFile(fp), zlib.decompressobj, + super().__init__(_PaddedFile(fp), zlib._ZlibDecompressor, wbits=-zlib.MAX_WBITS) # Set flag indicating start of a new member self._new_member = True @@ -502,12 +532,13 @@ def read(self, 
size=-1): self._new_member = False # Read a chunk of data from the file - buf = self._fp.read(io.DEFAULT_BUFFER_SIZE) + if self._decompressor.needs_input: + buf = self._fp.read(READ_BUFFER_SIZE) + uncompress = self._decompressor.decompress(buf, size) + else: + uncompress = self._decompressor.decompress(b"", size) - uncompress = self._decompressor.decompress(buf, size) - if self._decompressor.unconsumed_tail != b"": - self._fp.prepend(self._decompressor.unconsumed_tail) - elif self._decompressor.unused_data != b"": + if self._decompressor.unused_data != b"": # Prepend the already read bytes to the fileobj so they can # be seen by _read_eof() and _read_gzip_header() self._fp.prepend(self._decompressor.unused_data) @@ -518,14 +549,11 @@ def read(self, size=-1): raise EOFError("Compressed file ended before the " "end-of-stream marker was reached") - self._add_read_data( uncompress ) + self._crc = zlib.crc32(uncompress, self._crc) + self._stream_size += len(uncompress) self._pos += len(uncompress) return uncompress - def _add_read_data(self, data): - self._crc = zlib.crc32(data, self._crc) - self._stream_size = self._stream_size + len(data) - def _read_eof(self): # We've read to the end of the file # We check that the computed CRC and size of the @@ -552,43 +580,21 @@ def _rewind(self): self._new_member = True -def _create_simple_gzip_header(compresslevel: int, - mtime = None) -> bytes: - """ - Write a simple gzip header with no extra fields. - :param compresslevel: Compresslevel used to determine the xfl bytes. - :param mtime: The mtime (must support conversion to a 32-bit integer). - :return: A bytes object representing the gzip header. - """ - if mtime is None: - mtime = time.time() - if compresslevel == _COMPRESS_LEVEL_BEST: - xfl = 2 - elif compresslevel == _COMPRESS_LEVEL_FAST: - xfl = 4 - else: - xfl = 0 - # Pack ID1 and ID2 magic bytes, method (8=deflate), header flags (no extra - # fields added to header), mtime, xfl and os (255 for unknown OS). 
- return struct.pack("'%(self.name, self.levelno, - self.pathname, self.lineno, self.msg) + self.pathname, self.lineno, self.msg) def getMessage(self): """ @@ -487,7 +511,7 @@ def __init__(self, *args, **kwargs): def usesTime(self): fmt = self._fmt - return fmt.find('$asctime') >= 0 or fmt.find(self.asctime_format) >= 0 + return fmt.find('$asctime') >= 0 or fmt.find(self.asctime_search) >= 0 def validate(self): pattern = Template.pattern @@ -557,6 +581,7 @@ class Formatter(object): (typically at application startup time) %(thread)d Thread ID (if available) %(threadName)s Thread name (if available) + %(taskName)s Task name (if available) %(process)d Process ID (if available) %(message)s The result of record.getMessage(), computed just as the record is emitted @@ -583,7 +608,7 @@ def __init__(self, fmt=None, datefmt=None, style='%', validate=True, *, """ if style not in _STYLES: raise ValueError('Style must be one of: %s' % ','.join( - _STYLES.keys())) + _STYLES.keys())) self._style = _STYLES[style][0](fmt, defaults=defaults) if validate: self._style.validate() @@ -808,23 +833,36 @@ def filter(self, record): Determine if a record is loggable by consulting all the filters. The default is to allow the record to be logged; any filter can veto - this and the record is then dropped. Returns a zero value if a record - is to be dropped, else non-zero. + this by returning a false value. + If a filter attached to a handler returns a log record instance, + then that instance is used in place of the original log record in + any further processing of the event by that handler. + If a filter returns any other true value, the original log record + is used in any further processing of the event by that handler. + + If none of the filters return false values, this method returns + a log record. + If any of the filters return a false value, this method returns + a false value. .. versionchanged:: 3.2 Allow filters to be just callables. + + .. 
versionchanged:: 3.12 + Allow filters to return a LogRecord instead of + modifying it in place. """ - rv = True for f in self.filters: if hasattr(f, 'filter'): result = f.filter(record) else: result = f(record) # assume callable - will raise if not if not result: - rv = False - break - return rv + return False + if isinstance(result, LogRecord): + record = result + return record #--------------------------------------------------------------------------- # Handler classes and functions @@ -845,8 +883,9 @@ def _removeHandlerRef(wr): if acquire and release and handlers: acquire() try: - if wr in handlers: - handlers.remove(wr) + handlers.remove(wr) + except ValueError: + pass finally: release() @@ -860,6 +899,23 @@ def _addHandlerRef(handler): finally: _releaseLock() + +def getHandlerByName(name): + """ + Get a handler with the specified *name*, or None if there isn't one with + that name. + """ + return _handlers.get(name) + + +def getHandlerNames(): + """ + Return all known handler names as an immutable set. + """ + result = set(_handlers.keys()) + return frozenset(result) + + class Handler(Filterer): """ Handler instances dispatch logging events to specific destinations. @@ -958,10 +1014,14 @@ def handle(self, record): Emission depends on filters which may have been added to the handler. Wrap the actual emission of the record with acquisition/release of - the I/O thread lock. Returns whether the filter passed the record for - emission. + the I/O thread lock. + + Returns an instance of the log record that was emitted + if it passed all filters, otherwise a false value is returned. 
""" rv = self.filter(record) + if isinstance(rv, LogRecord): + record = rv if rv: self.acquire() try: @@ -1032,7 +1092,7 @@ def handleError(self, record): else: # couldn't find the right stack frame, for some reason sys.stderr.write('Logged from file %s, line %s\n' % ( - record.filename, record.lineno)) + record.filename, record.lineno)) # Issue 18671: output logging message and arguments try: sys.stderr.write('Message: %r\n' @@ -1044,7 +1104,7 @@ def handleError(self, record): sys.stderr.write('Unable to print the message and arguments' ' - possible formatting error.\nUse the' ' traceback above to help find the error.\n' - ) + ) except OSError: #pragma: no cover pass # see issue 5971 finally: @@ -1136,6 +1196,8 @@ def __repr__(self): name += ' ' return '<%s %s(%s)>' % (self.__class__.__name__, name, level) + __class_getitem__ = classmethod(GenericAlias) + class FileHandler(StreamHandler): """ @@ -1459,7 +1521,7 @@ def debug(self, msg, *args, **kwargs): To pass exception information, use the keyword argument exc_info with a true value, e.g. - logger.debug("Houston, we have a %s", "thorny problem", exc_info=1) + logger.debug("Houston, we have a %s", "thorny problem", exc_info=True) """ if self.isEnabledFor(DEBUG): self._log(DEBUG, msg, args, **kwargs) @@ -1471,7 +1533,7 @@ def info(self, msg, *args, **kwargs): To pass exception information, use the keyword argument exc_info with a true value, e.g. - logger.info("Houston, we have a %s", "interesting problem", exc_info=1) + logger.info("Houston, we have a %s", "notable problem", exc_info=True) """ if self.isEnabledFor(INFO): self._log(INFO, msg, args, **kwargs) @@ -1483,14 +1545,14 @@ def warning(self, msg, *args, **kwargs): To pass exception information, use the keyword argument exc_info with a true value, e.g. 
- logger.warning("Houston, we have a %s", "bit of a problem", exc_info=1) + logger.warning("Houston, we have a %s", "bit of a problem", exc_info=True) """ if self.isEnabledFor(WARNING): self._log(WARNING, msg, args, **kwargs) def warn(self, msg, *args, **kwargs): warnings.warn("The 'warn' method is deprecated, " - "use 'warning' instead", DeprecationWarning, 2) + "use 'warning' instead", DeprecationWarning, 2) self.warning(msg, *args, **kwargs) def error(self, msg, *args, **kwargs): @@ -1500,7 +1562,7 @@ def error(self, msg, *args, **kwargs): To pass exception information, use the keyword argument exc_info with a true value, e.g. - logger.error("Houston, we have a %s", "major problem", exc_info=1) + logger.error("Houston, we have a %s", "major problem", exc_info=True) """ if self.isEnabledFor(ERROR): self._log(ERROR, msg, args, **kwargs) @@ -1518,7 +1580,7 @@ def critical(self, msg, *args, **kwargs): To pass exception information, use the keyword argument exc_info with a true value, e.g. - logger.critical("Houston, we have a %s", "major disaster", exc_info=1) + logger.critical("Houston, we have a %s", "major disaster", exc_info=True) """ if self.isEnabledFor(CRITICAL): self._log(CRITICAL, msg, args, **kwargs) @@ -1536,7 +1598,7 @@ def log(self, level, msg, *args, **kwargs): To pass exception information, use the keyword argument exc_info with a true value, e.g. - logger.log(level, "We have a %s", "mysterious problem", exc_info=1) + logger.log(level, "We have a %s", "mysterious problem", exc_info=True) """ if not isinstance(level, int): if raiseExceptions: @@ -1554,33 +1616,31 @@ def findCaller(self, stack_info=False, stacklevel=1): f = currentframe() #On some versions of IronPython, currentframe() returns None if #IronPython isn't run with -X:Frames. 
- if f is not None: - f = f.f_back - orig_f = f - while f and stacklevel > 1: - f = f.f_back - stacklevel -= 1 - if not f: - f = orig_f - rv = "(unknown file)", 0, "(unknown function)", None - while hasattr(f, "f_code"): - co = f.f_code - filename = os.path.normcase(co.co_filename) - if filename == _srcfile: - f = f.f_back - continue - sinfo = None - if stack_info: - sio = io.StringIO() - sio.write('Stack (most recent call last):\n') + if f is None: + return "(unknown file)", 0, "(unknown function)", None + while stacklevel > 0: + next_f = f.f_back + if next_f is None: + ## We've got options here. + ## If we want to use the last (deepest) frame: + break + ## If we want to mimic the warnings module: + #return ("sys", 1, "(unknown function)", None) + ## If we want to be pedantic: + #raise ValueError("call stack is not deep enough") + f = next_f + if not _is_internal_frame(f): + stacklevel -= 1 + co = f.f_code + sinfo = None + if stack_info: + with io.StringIO() as sio: + sio.write("Stack (most recent call last):\n") traceback.print_stack(f, file=sio) sinfo = sio.getvalue() if sinfo[-1] == '\n': sinfo = sinfo[:-1] - sio.close() - rv = (co.co_filename, f.f_lineno, co.co_name, sinfo) - break - return rv + return co.co_filename, f.f_lineno, co.co_name, sinfo def makeRecord(self, name, level, fn, lno, msg, args, exc_info, func=None, extra=None, sinfo=None): @@ -1589,7 +1649,7 @@ def makeRecord(self, name, level, fn, lno, msg, args, exc_info, specialized LogRecords. """ rv = _logRecordFactory(name, level, fn, lno, msg, args, exc_info, func, - sinfo) + sinfo) if extra is not None: for key in extra: if (key in ["message", "asctime"]) or (key in rv.__dict__): @@ -1630,8 +1690,14 @@ def handle(self, record): This method is used for unpickled records received from a socket, as well as those created locally. Logger-level filtering is applied. 
""" - if (not self.disabled) and self.filter(record): - self.callHandlers(record) + if self.disabled: + return + maybe_record = self.filter(record) + if not maybe_record: + return + if isinstance(maybe_record, LogRecord): + record = maybe_record + self.callHandlers(record) def addHandler(self, hdlr): """ @@ -1737,7 +1803,7 @@ def isEnabledFor(self, level): is_enabled = self._cache[level] = False else: is_enabled = self._cache[level] = ( - level >= self.getEffectiveLevel() + level >= self.getEffectiveLevel() ) finally: _releaseLock() @@ -1762,13 +1828,30 @@ def getChild(self, suffix): suffix = '.'.join((self.name, suffix)) return self.manager.getLogger(suffix) + def getChildren(self): + + def _hierlevel(logger): + if logger is logger.manager.root: + return 0 + return 1 + logger.name.count('.') + + d = self.manager.loggerDict + _acquireLock() + try: + # exclude PlaceHolders - the last check is to ensure that lower-level + # descendants aren't returned - if there are placeholders, a logger's + # parent field might point to a grandparent or ancestor thereof. + return set(item for item in d.values() + if isinstance(item, Logger) and item.parent is self and + _hierlevel(item) == 1 + _hierlevel(item.parent)) + finally: + _releaseLock() + def __repr__(self): level = getLevelName(self.getEffectiveLevel()) return '<%s %s (%s)>' % (self.__class__.__name__, self.name, level) def __reduce__(self): - # In general, only the root logger will not be accessible via its name. - # However, the root logger's class has its own __reduce__ method. 
if getLogger(self.name) is not self: import pickle raise pickle.PicklingError('logger cannot be pickled') @@ -1848,7 +1931,7 @@ def warning(self, msg, *args, **kwargs): def warn(self, msg, *args, **kwargs): warnings.warn("The 'warn' method is deprecated, " - "use 'warning' instead", DeprecationWarning, 2) + "use 'warning' instead", DeprecationWarning, 2) self.warning(msg, *args, **kwargs) def error(self, msg, *args, **kwargs): @@ -1902,18 +1985,11 @@ def hasHandlers(self): """ return self.logger.hasHandlers() - def _log(self, level, msg, args, exc_info=None, extra=None, stack_info=False): + def _log(self, level, msg, args, **kwargs): """ Low-level log implementation, proxied to allow nested logger adapters. """ - return self.logger._log( - level, - msg, - args, - exc_info=exc_info, - extra=extra, - stack_info=stack_info, - ) + return self.logger._log(level, msg, args, **kwargs) @property def manager(self): @@ -1932,6 +2008,8 @@ def __repr__(self): level = getLevelName(logger.getEffectiveLevel()) return '<%s %s (%s)>' % (self.__class__.__name__, logger.name, level) + __class_getitem__ = classmethod(GenericAlias) + root = RootLogger(WARNING) Logger.root = root Logger.manager = Manager(Logger.root) @@ -1971,7 +2049,7 @@ def basicConfig(**kwargs): that this argument is incompatible with 'filename' - if both are present, 'stream' is ignored. handlers If specified, this should be an iterable of already created - handlers, which will be added to the root handler. Any handler + handlers, which will be added to the root logger. Any handler in the list which does not have a formatter assigned will be assigned the formatter created in this function. 
force If this keyword is specified as true, any existing handlers @@ -2047,7 +2125,7 @@ def basicConfig(**kwargs): style = kwargs.pop("style", '%') if style not in _STYLES: raise ValueError('Style must be one of: %s' % ','.join( - _STYLES.keys())) + _STYLES.keys())) fs = kwargs.pop("format", _STYLES[style][1]) fmt = Formatter(fs, dfs, style) for h in handlers: @@ -2124,7 +2202,7 @@ def warning(msg, *args, **kwargs): def warn(msg, *args, **kwargs): warnings.warn("The 'warn' function is deprecated, " - "use 'warning' instead", DeprecationWarning, 2) + "use 'warning' instead", DeprecationWarning, 2) warning(msg, *args, **kwargs) def info(msg, *args, **kwargs): @@ -2179,7 +2257,11 @@ def shutdown(handlerList=_handlerList): if h: try: h.acquire() - h.flush() + # MemoryHandlers might not want to be flushed on close, + # but circular imports prevent us scoping this to just + # those handlers. hence the default to True. + if getattr(h, 'flushOnClose', True): + h.flush() h.close() except (OSError, ValueError): # Ignore errors which might be caused @@ -2242,7 +2324,9 @@ def _showwarning(message, category, filename, lineno, file=None, line=None): logger = getLogger("py.warnings") if not logger.handlers: logger.addHandler(NullHandler()) - logger.warning("%s", s) + # bpo-46557: Log str(s) as msg instead of logger.warning("%s", s) + # since some log aggregation tools group logs by the msg arg + logger.warning(str(s)) def captureWarnings(capture): """ diff --git a/Lib/logging/config.py b/Lib/logging/config.py index 3bc63b7862..ef04a35168 100644 --- a/Lib/logging/config.py +++ b/Lib/logging/config.py @@ -1,4 +1,4 @@ -# Copyright 2001-2019 by Vinay Sajip. All Rights Reserved. +# Copyright 2001-2023 by Vinay Sajip. All Rights Reserved. 
# # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose and without fee is hereby granted, @@ -19,18 +19,20 @@ is based on PEP 282 and comments thereto in comp.lang.python, and influenced by Apache's log4j system. -Copyright (C) 2001-2019 Vinay Sajip. All Rights Reserved. +Copyright (C) 2001-2022 Vinay Sajip. All Rights Reserved. To use, simply 'import logging' and log away! """ import errno +import functools import io import logging import logging.handlers +import os +import queue import re import struct -import sys import threading import traceback @@ -59,15 +61,24 @@ def fileConfig(fname, defaults=None, disable_existing_loggers=True, encoding=Non """ import configparser + if isinstance(fname, str): + if not os.path.exists(fname): + raise FileNotFoundError(f"{fname} doesn't exist") + elif not os.path.getsize(fname): + raise RuntimeError(f'{fname} is an empty file') + if isinstance(fname, configparser.RawConfigParser): cp = fname else: - cp = configparser.ConfigParser(defaults) - if hasattr(fname, 'readline'): - cp.read_file(fname) - else: - encoding = io.text_encoding(encoding) - cp.read(fname, encoding=encoding) + try: + cp = configparser.ConfigParser(defaults) + if hasattr(fname, 'readline'): + cp.read_file(fname) + else: + encoding = io.text_encoding(encoding) + cp.read(fname, encoding=encoding) + except configparser.ParsingError as e: + raise RuntimeError(f'{fname} is invalid: {e}') formatters = _create_formatters(cp) @@ -113,11 +124,18 @@ def _create_formatters(cp): fs = cp.get(sectname, "format", raw=True, fallback=None) dfs = cp.get(sectname, "datefmt", raw=True, fallback=None) stl = cp.get(sectname, "style", raw=True, fallback='%') + defaults = cp.get(sectname, "defaults", raw=True, fallback=None) + c = logging.Formatter class_name = cp[sectname].get("class") if class_name: c = _resolve(class_name) - f = c(fs, dfs, stl) + + if defaults is not None: + defaults = eval(defaults, vars(logging)) + f = c(fs, 
dfs, stl, defaults=defaults) + else: + f = c(fs, dfs, stl) formatters[form] = f return formatters @@ -296,7 +314,7 @@ def convert_with_key(self, key, value, replace=True): if replace: self[key] = result if type(result) in (ConvertingDict, ConvertingList, - ConvertingTuple): + ConvertingTuple): result.parent = self result.key = key return result @@ -305,7 +323,7 @@ def convert(self, value): result = self.configurator.convert(value) if value is not result: if type(result) in (ConvertingDict, ConvertingList, - ConvertingTuple): + ConvertingTuple): result.parent = self return result @@ -392,11 +410,9 @@ def resolve(self, s): self.importer(used) found = getattr(found, frag) return found - except ImportError: - e, tb = sys.exc_info()[1:] + except ImportError as e: v = ValueError('Cannot resolve %r: %s' % (s, e)) - v.__cause__, v.__traceback__ = e, tb - raise v + raise v from e def ext_convert(self, value): """Default converter for the ext:// protocol.""" @@ -448,8 +464,8 @@ def convert(self, value): elif not isinstance(value, ConvertingList) and isinstance(value, list): value = ConvertingList(value) value.configurator = self - elif not isinstance(value, ConvertingTuple) and\ - isinstance(value, tuple) and not hasattr(value, '_fields'): + elif not isinstance(value, ConvertingTuple) and \ + isinstance(value, tuple) and not hasattr(value, '_fields'): value = ConvertingTuple(value) value.configurator = self elif isinstance(value, str): # str for py3k @@ -469,10 +485,10 @@ def configure_custom(self, config): c = config.pop('()') if not callable(c): c = self.resolve(c) - props = config.pop('.', None) # Check for valid identifiers - kwargs = {k: config[k] for k in config if valid_ident(k)} + kwargs = {k: config[k] for k in config if (k != '.' 
and valid_ident(k))} result = c(**kwargs) + props = config.pop('.', None) if props: for name, value in props.items(): setattr(result, name, value) @@ -484,6 +500,33 @@ def as_tuple(self, value): value = tuple(value) return value +def _is_queue_like_object(obj): + """Check that *obj* implements the Queue API.""" + if isinstance(obj, (queue.Queue, queue.SimpleQueue)): + return True + # defer importing multiprocessing as much as possible + from multiprocessing.queues import Queue as MPQueue + if isinstance(obj, MPQueue): + return True + # Depending on the multiprocessing start context, we cannot create + # a multiprocessing.managers.BaseManager instance 'mm' to get the + # runtime type of mm.Queue() or mm.JoinableQueue() (see gh-119819). + # + # Since we only need an object implementing the Queue API, we only + # do a protocol check, but we do not use typing.runtime_checkable() + # and typing.Protocol to reduce import time (see gh-121723). + # + # Ideally, we would have wanted to simply use strict type checking + # instead of a protocol-based type checking since the latter does + # not check the method signatures. + # + # Note that only 'put_nowait' and 'get' are required by the logging + # queue handler and queue listener (see gh-124653) and that other + # methods are either optional or unused. 
+ minimal_queue_interface = ['put_nowait', 'get'] + return all(callable(getattr(obj, method, None)) + for method in minimal_queue_interface) + class DictConfigurator(BaseConfigurator): """ Configure logging using a dictionary-like object to describe the @@ -542,7 +585,7 @@ def configure(self): for name in formatters: try: formatters[name] = self.configure_formatter( - formatters[name]) + formatters[name]) except Exception as e: raise ValueError('Unable to configure ' 'formatter %r' % name) from e @@ -566,7 +609,7 @@ def configure(self): handler.name = name handlers[name] = handler except Exception as e: - if 'target not configured yet' in str(e.__cause__): + if ' not configured yet' in str(e.__cause__): deferred.append(name) else: raise ValueError('Unable to configure handler ' @@ -669,18 +712,27 @@ def configure_formatter(self, config): dfmt = config.get('datefmt', None) style = config.get('style', '%') cname = config.get('class', None) + defaults = config.get('defaults', None) if not cname: c = logging.Formatter else: c = _resolve(cname) + kwargs = {} + + # Add defaults only if it exists. + # Prevents TypeError in custom formatter callables that do not + # accept it. 
+ if defaults is not None: + kwargs['defaults'] = defaults + # A TypeError would be raised if "validate" key is passed in with a formatter callable # that does not accept "validate" as a parameter if 'validate' in config: # if user hasn't mentioned it, the default will be fine - result = c(fmt, dfmt, style, config['validate']) + result = c(fmt, dfmt, style, config['validate'], **kwargs) else: - result = c(fmt, dfmt, style) + result = c(fmt, dfmt, style, **kwargs) return result @@ -697,10 +749,29 @@ def add_filters(self, filterer, filters): """Add filters to a filterer from a list of names.""" for f in filters: try: - filterer.addFilter(self.config['filters'][f]) + if callable(f) or callable(getattr(f, 'filter', None)): + filter_ = f + else: + filter_ = self.config['filters'][f] + filterer.addFilter(filter_) except Exception as e: raise ValueError('Unable to add filter %r' % f) from e + def _configure_queue_handler(self, klass, **kwargs): + if 'queue' in kwargs: + q = kwargs.pop('queue') + else: + q = queue.Queue() # unbounded + + rhl = kwargs.pop('respect_handler_level', False) + lklass = kwargs.pop('listener', logging.handlers.QueueListener) + handlers = kwargs.pop('handlers', []) + + listener = lklass(q, *handlers, respect_handler_level=rhl) + handler = klass(q, **kwargs) + handler.listener = listener + return handler + def configure_handler(self, config): """Configure a handler from a dictionary.""" config_copy = dict(config) # for restoring in case of error @@ -720,28 +791,87 @@ def configure_handler(self, config): factory = c else: cname = config.pop('class') - klass = self.resolve(cname) - #Special case for handler which refers to another handler - if issubclass(klass, logging.handlers.MemoryHandler) and\ - 'target' in config: - try: - th = self.config['handlers'][config['target']] - if not isinstance(th, logging.Handler): - config.update(config_copy) # restore for deferred cfg - raise TypeError('target not configured yet') - config['target'] = th - except 
Exception as e: - raise ValueError('Unable to set target handler ' - '%r' % config['target']) from e - elif issubclass(klass, logging.handlers.SMTPHandler) and\ - 'mailhost' in config: + if callable(cname): + klass = cname + else: + klass = self.resolve(cname) + if issubclass(klass, logging.handlers.MemoryHandler): + if 'flushLevel' in config: + config['flushLevel'] = logging._checkLevel(config['flushLevel']) + if 'target' in config: + # Special case for handler which refers to another handler + try: + tn = config['target'] + th = self.config['handlers'][tn] + if not isinstance(th, logging.Handler): + config.update(config_copy) # restore for deferred cfg + raise TypeError('target not configured yet') + config['target'] = th + except Exception as e: + raise ValueError('Unable to set target handler %r' % tn) from e + elif issubclass(klass, logging.handlers.QueueHandler): + # Another special case for handler which refers to other handlers + # if 'handlers' not in config: + # raise ValueError('No handlers specified for a QueueHandler') + if 'queue' in config: + qspec = config['queue'] + + if isinstance(qspec, str): + q = self.resolve(qspec) + if not callable(q): + raise TypeError('Invalid queue specifier %r' % qspec) + config['queue'] = q() + elif isinstance(qspec, dict): + if '()' not in qspec: + raise TypeError('Invalid queue specifier %r' % qspec) + config['queue'] = self.configure_custom(dict(qspec)) + elif not _is_queue_like_object(qspec): + raise TypeError('Invalid queue specifier %r' % qspec) + + if 'listener' in config: + lspec = config['listener'] + if isinstance(lspec, type): + if not issubclass(lspec, logging.handlers.QueueListener): + raise TypeError('Invalid listener specifier %r' % lspec) + else: + if isinstance(lspec, str): + listener = self.resolve(lspec) + if isinstance(listener, type) and \ + not issubclass(listener, logging.handlers.QueueListener): + raise TypeError('Invalid listener specifier %r' % lspec) + elif isinstance(lspec, dict): + if '()' 
not in lspec: + raise TypeError('Invalid listener specifier %r' % lspec) + listener = self.configure_custom(dict(lspec)) + else: + raise TypeError('Invalid listener specifier %r' % lspec) + if not callable(listener): + raise TypeError('Invalid listener specifier %r' % lspec) + config['listener'] = listener + if 'handlers' in config: + hlist = [] + try: + for hn in config['handlers']: + h = self.config['handlers'][hn] + if not isinstance(h, logging.Handler): + config.update(config_copy) # restore for deferred cfg + raise TypeError('Required handler %r ' + 'is not configured yet' % hn) + hlist.append(h) + except Exception as e: + raise ValueError('Unable to set required handler %r' % hn) from e + config['handlers'] = hlist + elif issubclass(klass, logging.handlers.SMTPHandler) and \ + 'mailhost' in config: config['mailhost'] = self.as_tuple(config['mailhost']) - elif issubclass(klass, logging.handlers.SysLogHandler) and\ - 'address' in config: + elif issubclass(klass, logging.handlers.SysLogHandler) and \ + 'address' in config: config['address'] = self.as_tuple(config['address']) - factory = klass - props = config.pop('.', None) - kwargs = {k: config[k] for k in config if valid_ident(k)} + if issubclass(klass, logging.handlers.QueueHandler): + factory = functools.partial(self._configure_queue_handler, klass) + else: + factory = klass + kwargs = {k: config[k] for k in config if (k != '.' 
and valid_ident(k))} try: result = factory(**kwargs) except TypeError as te: @@ -759,6 +889,7 @@ def configure_handler(self, config): result.setLevel(logging._checkLevel(level)) if filters: self.add_filters(result, filters) + props = config.pop('.', None) if props: for name, value in props.items(): setattr(result, name, value) @@ -794,6 +925,7 @@ def configure_logger(self, name, config, incremental=False): """Configure a non-root logger from a dictionary.""" logger = logging.getLogger(name) self.common_logger_config(logger, config, incremental) + logger.disabled = False propagate = config.get('propagate', None) if propagate is not None: logger.propagate = propagate diff --git a/Lib/logging/handlers.py b/Lib/logging/handlers.py index 61a39958c0..bf42ea1103 100644 --- a/Lib/logging/handlers.py +++ b/Lib/logging/handlers.py @@ -187,15 +187,18 @@ def shouldRollover(self, record): Basically, see if the supplied record would cause the file to exceed the size limit we have. """ - # See bpo-45401: Never rollover anything other than regular files - if os.path.exists(self.baseFilename) and not os.path.isfile(self.baseFilename): - return False if self.stream is None: # delay was set... self.stream = self._open() if self.maxBytes > 0: # are we rolling over? 
+ pos = self.stream.tell() + if not pos: + # gh-116263: Never rollover an empty file + return False msg = "%s\n" % self.format(record) - self.stream.seek(0, 2) #due to non-posix-compliant Windows feature - if self.stream.tell() + len(msg) >= self.maxBytes: + if pos + len(msg) >= self.maxBytes: + # See bpo-45401: Never rollover anything other than regular files + if os.path.exists(self.baseFilename) and not os.path.isfile(self.baseFilename): + return False return True return False @@ -232,19 +235,19 @@ def __init__(self, filename, when='h', interval=1, backupCount=0, if self.when == 'S': self.interval = 1 # one second self.suffix = "%Y-%m-%d_%H-%M-%S" - self.extMatch = r"^\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}(\.\w+)?$" + extMatch = r"(?= self.rolloverAt: + # See #89564: Never rollover anything other than regular files + if os.path.exists(self.baseFilename) and not os.path.isfile(self.baseFilename): + # The file is not a regular file, so do not rollover, but do + # set the next rollover time to avoid repeated checks. + self.rolloverAt = self.computeRollover(t) + return False + return True return False @@ -365,32 +382,28 @@ def getFilesToDelete(self): dirName, baseName = os.path.split(self.baseFilename) fileNames = os.listdir(dirName) result = [] - # See bpo-44753: Don't use the extension when computing the prefix. - n, e = os.path.splitext(baseName) - prefix = n + '.' 
- plen = len(prefix) - for fileName in fileNames: - if self.namer is None: - # Our files will always start with baseName - if not fileName.startswith(baseName): - continue - else: - # Our files could be just about anything after custom naming, but - # likely candidates are of the form - # foo.log.DATETIME_SUFFIX or foo.DATETIME_SUFFIX.log - if (not fileName.startswith(baseName) and fileName.endswith(e) and - len(fileName) > (plen + 1) and not fileName[plen+1].isdigit()): - continue - - if fileName[:plen] == prefix: - suffix = fileName[plen:] - # See bpo-45628: The date/time suffix could be anywhere in the - # filename - parts = suffix.split('.') - for part in parts: - if self.extMatch.match(part): + if self.namer is None: + prefix = baseName + '.' + plen = len(prefix) + for fileName in fileNames: + if fileName[:plen] == prefix: + suffix = fileName[plen:] + if self.extMatch.fullmatch(suffix): + result.append(os.path.join(dirName, fileName)) + else: + for fileName in fileNames: + # Our files could be just about anything after custom naming, + # but they should contain the datetime suffix. + # Try to find the datetime suffix in the file name and verify + # that the file name can be generated by this handler. + m = self.extMatch.search(fileName) + while m: + dfn = self.namer(self.baseFilename + "." + m[0]) + if os.path.basename(dfn) == fileName: result.append(os.path.join(dirName, fileName)) break + m = self.extMatch.search(fileName, m.start() + 1) + if len(result) < self.backupCount: result = [] else: @@ -406,17 +419,14 @@ def doRollover(self): then we have to get a list of matching filenames, sort them and remove the one with the oldest suffix. 
""" - if self.stream: - self.stream.close() - self.stream = None # get the time that this sequence started at and make it a TimeTuple currentTime = int(time.time()) - dstNow = time.localtime(currentTime)[-1] t = self.rolloverAt - self.interval if self.utc: timeTuple = time.gmtime(t) else: timeTuple = time.localtime(t) + dstNow = time.localtime(currentTime)[-1] dstThen = timeTuple[-1] if dstNow != dstThen: if dstNow: @@ -427,26 +437,19 @@ def doRollover(self): dfn = self.rotation_filename(self.baseFilename + "." + time.strftime(self.suffix, timeTuple)) if os.path.exists(dfn): - os.remove(dfn) + # Already rolled over. + return + + if self.stream: + self.stream.close() + self.stream = None self.rotate(self.baseFilename, dfn) if self.backupCount > 0: for s in self.getFilesToDelete(): os.remove(s) if not self.delay: self.stream = self._open() - newRolloverAt = self.computeRollover(currentTime) - while newRolloverAt <= currentTime: - newRolloverAt = newRolloverAt + self.interval - #If DST changes and midnight or weekly rollover, adjust for this. - if (self.when == 'MIDNIGHT' or self.when.startswith('W')) and not self.utc: - dstAtRollover = time.localtime(newRolloverAt)[-1] - if dstNow != dstAtRollover: - if not dstNow: # DST kicks in before next rollover, so we need to deduct an hour - addend = -3600 - else: # DST bows out before next rollover, so we need to add an hour - addend = 3600 - newRolloverAt += addend - self.rolloverAt = newRolloverAt + self.rolloverAt = self.computeRollover(currentTime) class WatchedFileHandler(logging.FileHandler): """ @@ -800,7 +803,7 @@ class SysLogHandler(logging.Handler): "panic": LOG_EMERG, # DEPRECATED "warn": LOG_WARNING, # DEPRECATED "warning": LOG_WARNING, - } + } facility_names = { "auth": LOG_AUTH, @@ -827,12 +830,10 @@ class SysLogHandler(logging.Handler): "local5": LOG_LOCAL5, "local6": LOG_LOCAL6, "local7": LOG_LOCAL7, - } + } - #The map below appears to be trivially lowercasing the key. 
However, - #there's more to it than meets the eye - in some locales, lowercasing - #gives unexpected results. See SF #1524081: in the Turkish locale, - #"INFO".lower() != "info" + # Originally added to work around GH-43683. Unnecessary since GH-50043 but kept + # for backwards compatibility. priority_map = { "DEBUG" : "debug", "INFO" : "info", @@ -859,12 +860,49 @@ def __init__(self, address=('localhost', SYSLOG_UDP_PORT), self.address = address self.facility = facility self.socktype = socktype + self.socket = None + self.createSocket() + + def _connect_unixsocket(self, address): + use_socktype = self.socktype + if use_socktype is None: + use_socktype = socket.SOCK_DGRAM + self.socket = socket.socket(socket.AF_UNIX, use_socktype) + try: + self.socket.connect(address) + # it worked, so set self.socktype to the used type + self.socktype = use_socktype + except OSError: + self.socket.close() + if self.socktype is not None: + # user didn't specify falling back, so fail + raise + use_socktype = socket.SOCK_STREAM + self.socket = socket.socket(socket.AF_UNIX, use_socktype) + try: + self.socket.connect(address) + # it worked, so set self.socktype to the used type + self.socktype = use_socktype + except OSError: + self.socket.close() + raise + + def createSocket(self): + """ + Try to create a socket and, if it's not a datagram socket, connect it + to the other end. This method is called during handler initialization, + but it's not regarded as an error if the other end isn't listening yet + --- the method will be called again when emitting an event, + if there is no socket at that point. + """ + address = self.address + socktype = self.socktype if isinstance(address, str): self.unixsocket = True # Syslog server may be unavailable during handler initialisation. # C's openlog() function also ignores connection errors. 
- # Moreover, we ignore these errors while logging, so it not worse + # Moreover, we ignore these errors while logging, so it's not worse # to ignore it also here. try: self._connect_unixsocket(address) @@ -895,30 +933,6 @@ def __init__(self, address=('localhost', SYSLOG_UDP_PORT), self.socket = sock self.socktype = socktype - def _connect_unixsocket(self, address): - use_socktype = self.socktype - if use_socktype is None: - use_socktype = socket.SOCK_DGRAM - self.socket = socket.socket(socket.AF_UNIX, use_socktype) - try: - self.socket.connect(address) - # it worked, so set self.socktype to the used type - self.socktype = use_socktype - except OSError: - self.socket.close() - if self.socktype is not None: - # user didn't specify falling back, so fail - raise - use_socktype = socket.SOCK_STREAM - self.socket = socket.socket(socket.AF_UNIX, use_socktype) - try: - self.socket.connect(address) - # it worked, so set self.socktype to the used type - self.socktype = use_socktype - except OSError: - self.socket.close() - raise - def encodePriority(self, facility, priority): """ Encode the facility and priority. You can pass in strings or @@ -938,7 +952,10 @@ def close(self): """ self.acquire() try: - self.socket.close() + sock = self.socket + if sock: + self.socket = None + sock.close() logging.Handler.close(self) finally: self.release() @@ -978,6 +995,10 @@ def emit(self, record): # Message is a string. Convert to bytes as required by RFC 5424 msg = msg.encode('utf-8') msg = prio + msg + + if not self.socket: + self.createSocket() + if self.unixsocket: try: self.socket.send(msg) @@ -1094,7 +1115,16 @@ def __init__(self, appname, dllname=None, logtype="Application"): dllname = os.path.join(dllname[0], r'win32service.pyd') self.dllname = dllname self.logtype = logtype - self._welu.AddSourceToRegistry(appname, dllname, logtype) + # Administrative privileges are required to add a source to the registry. 
+ # This may not be available for a user that just wants to add to an + # existing source - handle this specific case. + try: + self._welu.AddSourceToRegistry(appname, dllname, logtype) + except Exception as e: + # This will probably be a pywintypes.error. Only raise if it's not + # an "access denied" error, else let it pass + if getattr(e, 'winerror', None) != 5: # not access denied + raise self.deftype = win32evtlog.EVENTLOG_ERROR_TYPE self.typemap = { logging.DEBUG : win32evtlog.EVENTLOG_INFORMATION_TYPE, @@ -1102,10 +1132,10 @@ def __init__(self, appname, dllname=None, logtype="Application"): logging.WARNING : win32evtlog.EVENTLOG_WARNING_TYPE, logging.ERROR : win32evtlog.EVENTLOG_ERROR_TYPE, logging.CRITICAL: win32evtlog.EVENTLOG_ERROR_TYPE, - } + } except ImportError: - print("The Python Win32 extensions for NT (service, event "\ - "logging) appear not to be available.") + print("The Python Win32 extensions for NT (service, event " \ + "logging) appear not to be available.") self._welu = None def getMessageID(self, record): @@ -1348,7 +1378,7 @@ def shouldFlush(self, record): Check for buffer full or a record at the flushLevel or higher. """ return (len(self.buffer) >= self.capacity) or \ - (record.levelno >= self.flushLevel) + (record.levelno >= self.flushLevel) def setTarget(self, target): """ @@ -1366,7 +1396,7 @@ def flush(self): records to the target, if there is one. Override if you want different behaviour. - The record buffer is also cleared by this operation. + The record buffer is only cleared if a target has been set. """ self.acquire() try: @@ -1411,6 +1441,7 @@ def __init__(self, queue): """ logging.Handler.__init__(self) self.queue = queue + self.listener = None # will be set to listener if configured via dictConfig() def enqueue(self, record): """ @@ -1424,12 +1455,15 @@ def enqueue(self, record): def prepare(self, record): """ - Prepares a record for queuing. The object returned by this method is + Prepare a record for queuing. 
The object returned by this method is enqueued. - The base implementation formats the record to merge the message - and arguments, and removes unpickleable items from the record - in-place. + The base implementation formats the record to merge the message and + arguments, and removes unpickleable items from the record in-place. + Specifically, it overwrites the record's `msg` and + `message` attributes with the merged message (obtained by + calling the handler's `format` method), and sets the `args`, + `exc_info` and `exc_text` attributes to None. You might want to override this method if you want to convert the record to a dict or JSON string, or send a modified copy @@ -1439,7 +1473,7 @@ def prepare(self, record): # (if there's exception data), and also returns the formatted # message. We can then use this to replace the original # msg + args, as these might be unpickleable. We also zap the - # exc_info and exc_text attributes, as they are no longer + # exc_info, exc_text and stack_info attributes, as they are no longer # needed and, if not None, will typically not be pickleable. msg = self.format(record) # bpo-35726: make copy of record to avoid affecting other handlers in the chain. @@ -1449,6 +1483,7 @@ def prepare(self, record): record.args = None record.exc_info = None record.exc_text = None + record.stack_info = None return record def emit(self, record): diff --git a/Lib/smtplib.py b/Lib/smtplib.py new file mode 100644 index 0000000000..912233d817 --- /dev/null +++ b/Lib/smtplib.py @@ -0,0 +1,1109 @@ +#! /usr/bin/env python3 + +'''SMTP/ESMTP client class. + +This should follow RFC 821 (SMTP), RFC 1869 (ESMTP), RFC 2554 (SMTP +Authentication) and RFC 2487 (Secure SMTP over TLS). + +Notes: + +Please remember, when doing ESMTP, that the names of the SMTP service +extensions are NOT the same thing as the option keywords for the RCPT +and MAIL commands! 
+ +Example: + + >>> import smtplib + >>> s=smtplib.SMTP("localhost") + >>> print(s.help()) + This is Sendmail version 8.8.4 + Topics: + HELO EHLO MAIL RCPT DATA + RSET NOOP QUIT HELP VRFY + EXPN VERB ETRN DSN + For more info use "HELP ". + To report bugs in the implementation send email to + sendmail-bugs@sendmail.org. + For local information send email to Postmaster at your site. + End of HELP info + >>> s.putcmd("vrfy","someone@here") + >>> s.getreply() + (250, "Somebody OverHere ") + >>> s.quit() +''' + +# Author: The Dragon De Monsyne +# ESMTP support, test code and doc fixes added by +# Eric S. Raymond +# Better RFC 821 compliance (MAIL and RCPT, and CRLF in data) +# by Carey Evans , for picky mail servers. +# RFC 2554 (authentication) support by Gerhard Haering . +# +# This was modified from the Python 1.5 library HTTP lib. + +import socket +import io +import re +import email.utils +import email.message +import email.generator +import base64 +import hmac +import copy +import datetime +import sys +from email.base64mime import body_encode as encode_base64 + +__all__ = ["SMTPException", "SMTPNotSupportedError", "SMTPServerDisconnected", "SMTPResponseException", + "SMTPSenderRefused", "SMTPRecipientsRefused", "SMTPDataError", + "SMTPConnectError", "SMTPHeloError", "SMTPAuthenticationError", + "quoteaddr", "quotedata", "SMTP"] + +SMTP_PORT = 25 +SMTP_SSL_PORT = 465 +CRLF = "\r\n" +bCRLF = b"\r\n" +_MAXLINE = 8192 # more than 8 times larger than RFC 821, 4.5.3 +_MAXCHALLENGE = 5 # Maximum number of AUTH challenges sent + +OLDSTYLE_AUTH = re.compile(r"auth=(.*)", re.I) + +# Exception classes used by this module. +class SMTPException(OSError): + """Base class for all exceptions raised by this module.""" + +class SMTPNotSupportedError(SMTPException): + """The command or option is not supported by the SMTP server. + + This exception is raised when an attempt is made to run a command or a + command with an option which is not supported by the server. 
+ """ + +class SMTPServerDisconnected(SMTPException): + """Not connected to any SMTP server. + + This exception is raised when the server unexpectedly disconnects, + or when an attempt is made to use the SMTP instance before + connecting it to a server. + """ + +class SMTPResponseException(SMTPException): + """Base class for all exceptions that include an SMTP error code. + + These exceptions are generated in some instances when the SMTP + server returns an error code. The error code is stored in the + `smtp_code' attribute of the error, and the `smtp_error' attribute + is set to the error message. + """ + + def __init__(self, code, msg): + self.smtp_code = code + self.smtp_error = msg + self.args = (code, msg) + +class SMTPSenderRefused(SMTPResponseException): + """Sender address refused. + + In addition to the attributes set by on all SMTPResponseException + exceptions, this sets `sender' to the string that the SMTP refused. + """ + + def __init__(self, code, msg, sender): + self.smtp_code = code + self.smtp_error = msg + self.sender = sender + self.args = (code, msg, sender) + +class SMTPRecipientsRefused(SMTPException): + """All recipient addresses refused. + + The errors for each recipient are accessible through the attribute + 'recipients', which is a dictionary of exactly the same sort as + SMTP.sendmail() returns. + """ + + def __init__(self, recipients): + self.recipients = recipients + self.args = (recipients,) + + +class SMTPDataError(SMTPResponseException): + """The SMTP server didn't accept the data.""" + +class SMTPConnectError(SMTPResponseException): + """Error during connection establishment.""" + +class SMTPHeloError(SMTPResponseException): + """The server refused our HELO reply.""" + +class SMTPAuthenticationError(SMTPResponseException): + """Authentication error. + + Most probably the server didn't accept the username/password + combination provided. 
+ """ + +def quoteaddr(addrstring): + """Quote a subset of the email addresses defined by RFC 821. + + Should be able to handle anything email.utils.parseaddr can handle. + """ + displayname, addr = email.utils.parseaddr(addrstring) + if (displayname, addr) == ('', ''): + # parseaddr couldn't parse it, use it as is and hope for the best. + if addrstring.strip().startswith('<'): + return addrstring + return "<%s>" % addrstring + return "<%s>" % addr + +def _addr_only(addrstring): + displayname, addr = email.utils.parseaddr(addrstring) + if (displayname, addr) == ('', ''): + # parseaddr couldn't parse it, so use it as is. + return addrstring + return addr + +# Legacy method kept for backward compatibility. +def quotedata(data): + """Quote data for email. + + Double leading '.', and change Unix newline '\\n', or Mac '\\r' into + internet CRLF end-of-line. + """ + return re.sub(r'(?m)^\.', '..', + re.sub(r'(?:\r\n|\n|\r(?!\n))', CRLF, data)) + +def _quote_periods(bindata): + return re.sub(br'(?m)^\.', b'..', bindata) + +def _fix_eols(data): + return re.sub(r'(?:\r\n|\n|\r(?!\n))', CRLF, data) + +try: + import ssl +except ImportError: + _have_ssl = False +else: + _have_ssl = True + + +class SMTP: + """This class manages a connection to an SMTP or ESMTP server. + SMTP Objects: + SMTP objects have the following attributes: + helo_resp + This is the message given by the server in response to the + most recent HELO command. + + ehlo_resp + This is the message given by the server in response to the + most recent EHLO command. This is usually multiline. + + does_esmtp + This is a True value _after you do an EHLO command_, if the + server supports ESMTP. + + esmtp_features + This is a dictionary, which, if the server supports ESMTP, + will _after you do an EHLO command_, contain the names of the + SMTP service extensions this server supports, and their + parameters (if any). + + Note, all extension names are mapped to lower case in the + dictionary. 
+ + See each method's docstrings for details. In general, there is a + method of the same name to perform each SMTP command. There is also a + method called 'sendmail' that will do an entire mail transaction. + """ + debuglevel = 0 + + sock = None + file = None + helo_resp = None + ehlo_msg = "ehlo" + ehlo_resp = None + does_esmtp = False + default_port = SMTP_PORT + + def __init__(self, host='', port=0, local_hostname=None, + timeout=socket._GLOBAL_DEFAULT_TIMEOUT, + source_address=None): + """Initialize a new instance. + + If specified, `host` is the name of the remote host to which to + connect. If specified, `port` specifies the port to which to connect. + By default, smtplib.SMTP_PORT is used. If a host is specified the + connect method is called, and if it returns anything other than a + success code an SMTPConnectError is raised. If specified, + `local_hostname` is used as the FQDN of the local host in the HELO/EHLO + command. Otherwise, the local hostname is found using + socket.getfqdn(). The `source_address` parameter takes a 2-tuple (host, + port) for the socket to bind to as its source address before + connecting. If the host is '' and port is 0, the OS default behavior + will be used. + + """ + self._host = host + self.timeout = timeout + self.esmtp_features = {} + self.command_encoding = 'ascii' + self.source_address = source_address + self._auth_challenge_count = 0 + + if host: + (code, msg) = self.connect(host, port) + if code != 220: + self.close() + raise SMTPConnectError(code, msg) + if local_hostname is not None: + self.local_hostname = local_hostname + else: + # RFC 2821 says we should use the fqdn in the EHLO/HELO verb, and + # if that can't be calculated, that we should use a domain literal + # instead (essentially an encoded IP address like [A.B.C.D]). + fqdn = socket.getfqdn() + if '.' 
in fqdn: + self.local_hostname = fqdn + else: + # We can't find an fqdn hostname, so use a domain literal + addr = '127.0.0.1' + try: + addr = socket.gethostbyname(socket.gethostname()) + except socket.gaierror: + pass + self.local_hostname = '[%s]' % addr + + def __enter__(self): + return self + + def __exit__(self, *args): + try: + code, message = self.docmd("QUIT") + if code != 221: + raise SMTPResponseException(code, message) + except SMTPServerDisconnected: + pass + finally: + self.close() + + def set_debuglevel(self, debuglevel): + """Set the debug output level. + + A non-false value results in debug messages for connection and for all + messages sent to and received from the server. + + """ + self.debuglevel = debuglevel + + def _print_debug(self, *args): + if self.debuglevel > 1: + print(datetime.datetime.now().time(), *args, file=sys.stderr) + else: + print(*args, file=sys.stderr) + + def _get_socket(self, host, port, timeout): + # This makes it simpler for SMTP_SSL to use the SMTP connect code + # and just alter the socket connection bit. + if timeout is not None and not timeout: + raise ValueError('Non-blocking socket (timeout=0) is not supported') + if self.debuglevel > 0: + self._print_debug('connect: to', (host, port), self.source_address) + return socket.create_connection((host, port), timeout, + self.source_address) + + def connect(self, host='localhost', port=0, source_address=None): + """Connect to a host on a given port. + + If the hostname ends with a colon (`:') followed by a number, and + there is no port specified, that suffix will be stripped off and the + number interpreted as the port number to use. + + Note: This method is automatically invoked by __init__, if a host is + specified during instantiation. 
+ + """ + + if source_address: + self.source_address = source_address + + if not port and (host.find(':') == host.rfind(':')): + i = host.rfind(':') + if i >= 0: + host, port = host[:i], host[i + 1:] + try: + port = int(port) + except ValueError: + raise OSError("nonnumeric port") + if not port: + port = self.default_port + sys.audit("smtplib.connect", self, host, port) + self.sock = self._get_socket(host, port, self.timeout) + self.file = None + (code, msg) = self.getreply() + if self.debuglevel > 0: + self._print_debug('connect:', repr(msg)) + return (code, msg) + + def send(self, s): + """Send `s' to the server.""" + if self.debuglevel > 0: + self._print_debug('send:', repr(s)) + if self.sock: + if isinstance(s, str): + # send is used by the 'data' command, where command_encoding + # should not be used, but 'data' needs to convert the string to + # binary itself anyway, so that's not a problem. + s = s.encode(self.command_encoding) + sys.audit("smtplib.send", self, s) + try: + self.sock.sendall(s) + except OSError: + self.close() + raise SMTPServerDisconnected('Server not connected') + else: + raise SMTPServerDisconnected('please run connect() first') + + def putcmd(self, cmd, args=""): + """Send a command to the server.""" + if args == "": + s = cmd + else: + s = f'{cmd} {args}' + if '\r' in s or '\n' in s: + s = s.replace('\n', '\\n').replace('\r', '\\r') + raise ValueError( + f'command and arguments contain prohibited newline characters: {s}' + ) + self.send(f'{s}{CRLF}') + + def getreply(self): + """Get a reply from the server. + + Returns a tuple consisting of: + + - server response code (e.g. '250', or such, if all goes well) + Note: returns -1 if it can't read response code. + + - server response string corresponding to response code (multiline + responses are converted to a single, multiline string). + + Raises SMTPServerDisconnected if end-of-file is reached. 
+ """ + resp = [] + if self.file is None: + self.file = self.sock.makefile('rb') + while 1: + try: + line = self.file.readline(_MAXLINE + 1) + except OSError as e: + self.close() + raise SMTPServerDisconnected("Connection unexpectedly closed: " + + str(e)) + if not line: + self.close() + raise SMTPServerDisconnected("Connection unexpectedly closed") + if self.debuglevel > 0: + self._print_debug('reply:', repr(line)) + if len(line) > _MAXLINE: + self.close() + raise SMTPResponseException(500, "Line too long.") + resp.append(line[4:].strip(b' \t\r\n')) + code = line[:3] + # Check that the error code is syntactically correct. + # Don't attempt to read a continuation line if it is broken. + try: + errcode = int(code) + except ValueError: + errcode = -1 + break + # Check if multiline response. + if line[3:4] != b"-": + break + + errmsg = b"\n".join(resp) + if self.debuglevel > 0: + self._print_debug('reply: retcode (%s); Msg: %a' % (errcode, errmsg)) + return errcode, errmsg + + def docmd(self, cmd, args=""): + """Send a command, and return its response code.""" + self.putcmd(cmd, args) + return self.getreply() + + # std smtp commands + def helo(self, name=''): + """SMTP 'helo' command. + Hostname to send for this command defaults to the FQDN of the local + host. + """ + self.putcmd("helo", name or self.local_hostname) + (code, msg) = self.getreply() + self.helo_resp = msg + return (code, msg) + + def ehlo(self, name=''): + """ SMTP 'ehlo' command. + Hostname to send for this command defaults to the FQDN of the local + host. + """ + self.esmtp_features = {} + self.putcmd(self.ehlo_msg, name or self.local_hostname) + (code, msg) = self.getreply() + # According to RFC1869 some (badly written) + # MTA's will disconnect on an ehlo. 
Toss an exception if + # that happens -ddm + if code == -1 and len(msg) == 0: + self.close() + raise SMTPServerDisconnected("Server not connected") + self.ehlo_resp = msg + if code != 250: + return (code, msg) + self.does_esmtp = True + #parse the ehlo response -ddm + assert isinstance(self.ehlo_resp, bytes), repr(self.ehlo_resp) + resp = self.ehlo_resp.decode("latin-1").split('\n') + del resp[0] + for each in resp: + # To be able to communicate with as many SMTP servers as possible, + # we have to take the old-style auth advertisement into account, + # because: + # 1) Else our SMTP feature parser gets confused. + # 2) There are some servers that only advertise the auth methods we + # support using the old style. + auth_match = OLDSTYLE_AUTH.match(each) + if auth_match: + # This doesn't remove duplicates, but that's no problem + self.esmtp_features["auth"] = self.esmtp_features.get("auth", "") \ + + " " + auth_match.groups(0)[0] + continue + + # RFC 1869 requires a space between ehlo keyword and parameters. + # It's actually stricter, in that only spaces are allowed between + # parameters, but were not going to check for that here. Note + # that the space isn't present if there are no parameters. + m = re.match(r'(?P[A-Za-z0-9][A-Za-z0-9\-]*) ?', each) + if m: + feature = m.group("feature").lower() + params = m.string[m.end("feature"):].strip() + if feature == "auth": + self.esmtp_features[feature] = self.esmtp_features.get(feature, "") \ + + " " + params + else: + self.esmtp_features[feature] = params + return (code, msg) + + def has_extn(self, opt): + """Does the server support a given SMTP service extension?""" + return opt.lower() in self.esmtp_features + + def help(self, args=''): + """SMTP 'help' command. 
+ Returns help text from server.""" + self.putcmd("help", args) + return self.getreply()[1] + + def rset(self): + """SMTP 'rset' command -- resets session.""" + self.command_encoding = 'ascii' + return self.docmd("rset") + + def _rset(self): + """Internal 'rset' command which ignores any SMTPServerDisconnected error. + + Used internally in the library, since the server disconnected error + should appear to the application when the *next* command is issued, if + we are doing an internal "safety" reset. + """ + try: + self.rset() + except SMTPServerDisconnected: + pass + + def noop(self): + """SMTP 'noop' command -- doesn't do anything :>""" + return self.docmd("noop") + + def mail(self, sender, options=()): + """SMTP 'mail' command -- begins mail xfer session. + + This method may raise the following exceptions: + + SMTPNotSupportedError The options parameter includes 'SMTPUTF8' + but the SMTPUTF8 extension is not supported by + the server. + """ + optionlist = '' + if options and self.does_esmtp: + if any(x.lower()=='smtputf8' for x in options): + if self.has_extn('smtputf8'): + self.command_encoding = 'utf-8' + else: + raise SMTPNotSupportedError( + 'SMTPUTF8 not supported by server') + optionlist = ' ' + ' '.join(options) + self.putcmd("mail", "FROM:%s%s" % (quoteaddr(sender), optionlist)) + return self.getreply() + + def rcpt(self, recip, options=()): + """SMTP 'rcpt' command -- indicates 1 recipient for this mail.""" + optionlist = '' + if options and self.does_esmtp: + optionlist = ' ' + ' '.join(options) + self.putcmd("rcpt", "TO:%s%s" % (quoteaddr(recip), optionlist)) + return self.getreply() + + def data(self, msg): + """SMTP 'DATA' command -- sends message data to server. + + Automatically quotes lines beginning with a period per rfc821. + Raises SMTPDataError if there is an unexpected reply to the + DATA command; the return value from this method is the final + response code received when the all data is sent. 
If msg + is a string, lone '\\r' and '\\n' characters are converted to + '\\r\\n' characters. If msg is bytes, it is transmitted as is. + """ + self.putcmd("data") + (code, repl) = self.getreply() + if self.debuglevel > 0: + self._print_debug('data:', (code, repl)) + if code != 354: + raise SMTPDataError(code, repl) + else: + if isinstance(msg, str): + msg = _fix_eols(msg).encode('ascii') + q = _quote_periods(msg) + if q[-2:] != bCRLF: + q = q + bCRLF + q = q + b"." + bCRLF + self.send(q) + (code, msg) = self.getreply() + if self.debuglevel > 0: + self._print_debug('data:', (code, msg)) + return (code, msg) + + def verify(self, address): + """SMTP 'verify' command -- checks for address validity.""" + self.putcmd("vrfy", _addr_only(address)) + return self.getreply() + # a.k.a. + vrfy = verify + + def expn(self, address): + """SMTP 'expn' command -- expands a mailing list.""" + self.putcmd("expn", _addr_only(address)) + return self.getreply() + + # some useful methods + + def ehlo_or_helo_if_needed(self): + """Call self.ehlo() and/or self.helo() if needed. + + If there has been no previous EHLO or HELO command this session, this + method tries ESMTP EHLO first. + + This method may raise the following exceptions: + + SMTPHeloError The server didn't reply properly to + the helo greeting. + """ + if self.helo_resp is None and self.ehlo_resp is None: + if not (200 <= self.ehlo()[0] <= 299): + (code, resp) = self.helo() + if not (200 <= code <= 299): + raise SMTPHeloError(code, resp) + + def auth(self, mechanism, authobject, *, initial_response_ok=True): + """Authentication command - requires response processing. + + 'mechanism' specifies which authentication mechanism is to + be used - the valid values are those listed in the 'auth' + element of 'esmtp_features'. 
+ + 'authobject' must be a callable object taking a single argument: + + data = authobject(challenge) + + It will be called to process the server's challenge response; the + challenge argument it is passed will be a bytes. It should return + an ASCII string that will be base64 encoded and sent to the server. + + Keyword arguments: + - initial_response_ok: Allow sending the RFC 4954 initial-response + to the AUTH command, if the authentication methods supports it. + """ + # RFC 4954 allows auth methods to provide an initial response. Not all + # methods support it. By definition, if they return something other + # than None when challenge is None, then they do. See issue #15014. + mechanism = mechanism.upper() + initial_response = (authobject() if initial_response_ok else None) + if initial_response is not None: + response = encode_base64(initial_response.encode('ascii'), eol='') + (code, resp) = self.docmd("AUTH", mechanism + " " + response) + self._auth_challenge_count = 1 + else: + (code, resp) = self.docmd("AUTH", mechanism) + self._auth_challenge_count = 0 + # If server responds with a challenge, send the response. + while code == 334: + self._auth_challenge_count += 1 + challenge = base64.decodebytes(resp) + response = encode_base64( + authobject(challenge).encode('ascii'), eol='') + (code, resp) = self.docmd(response) + # If server keeps sending challenges, something is wrong. + if self._auth_challenge_count > _MAXCHALLENGE: + raise SMTPException( + "Server AUTH mechanism infinite loop. Last response: " + + repr((code, resp)) + ) + if code in (235, 503): + return (code, resp) + raise SMTPAuthenticationError(code, resp) + + def auth_cram_md5(self, challenge=None): + """ Authobject to use with CRAM-MD5 authentication. Requires self.user + and self.password to be set.""" + # CRAM-MD5 does not support initial-response. 
+ if challenge is None: + return None + return self.user + " " + hmac.HMAC( + self.password.encode('ascii'), challenge, 'md5').hexdigest() + + def auth_plain(self, challenge=None): + """ Authobject to use with PLAIN authentication. Requires self.user and + self.password to be set.""" + return "\0%s\0%s" % (self.user, self.password) + + def auth_login(self, challenge=None): + """ Authobject to use with LOGIN authentication. Requires self.user and + self.password to be set.""" + if challenge is None or self._auth_challenge_count < 2: + return self.user + else: + return self.password + + def login(self, user, password, *, initial_response_ok=True): + """Log in on an SMTP server that requires authentication. + + The arguments are: + - user: The user name to authenticate with. + - password: The password for the authentication. + + Keyword arguments: + - initial_response_ok: Allow sending the RFC 4954 initial-response + to the AUTH command, if the authentication methods supports it. + + If there has been no previous EHLO or HELO command this session, this + method tries ESMTP EHLO first. + + This method will return normally if the authentication was successful. + + This method may raise the following exceptions: + + SMTPHeloError The server didn't reply properly to + the helo greeting. + SMTPAuthenticationError The server didn't accept the username/ + password combination. + SMTPNotSupportedError The AUTH command is not supported by the + server. + SMTPException No suitable authentication method was + found. 
+ """ + + self.ehlo_or_helo_if_needed() + if not self.has_extn("auth"): + raise SMTPNotSupportedError( + "SMTP AUTH extension not supported by server.") + + # Authentication methods the server claims to support + advertised_authlist = self.esmtp_features["auth"].split() + + # Authentication methods we can handle in our preferred order: + preferred_auths = ['CRAM-MD5', 'PLAIN', 'LOGIN'] + + # We try the supported authentications in our preferred order, if + # the server supports them. + authlist = [auth for auth in preferred_auths + if auth in advertised_authlist] + if not authlist: + raise SMTPException("No suitable authentication method found.") + + # Some servers advertise authentication methods they don't really + # support, so if authentication fails, we continue until we've tried + # all methods. + self.user, self.password = user, password + for authmethod in authlist: + method_name = 'auth_' + authmethod.lower().replace('-', '_') + try: + (code, resp) = self.auth( + authmethod, getattr(self, method_name), + initial_response_ok=initial_response_ok) + # 235 == 'Authentication successful' + # 503 == 'Error: already authenticated' + if code in (235, 503): + return (code, resp) + except SMTPAuthenticationError as e: + last_exception = e + + # We could not login successfully. Return result of last attempt. + raise last_exception + + def starttls(self, *, context=None): + """Puts the connection to the SMTP server into TLS mode. + + If there has been no previous EHLO or HELO command this session, this + method tries ESMTP EHLO first. + + If the server supports TLS, this will encrypt the rest of the SMTP + session. If you provide the context parameter, + the identity of the SMTP server and client can be checked. This, + however, depends on whether the socket module really checks the + certificates. + + This method may raise the following exceptions: + + SMTPHeloError The server didn't reply properly to + the helo greeting. 
+ """ + self.ehlo_or_helo_if_needed() + if not self.has_extn("starttls"): + raise SMTPNotSupportedError( + "STARTTLS extension not supported by server.") + (resp, reply) = self.docmd("STARTTLS") + if resp == 220: + if not _have_ssl: + raise RuntimeError("No SSL support included in this Python") + if context is None: + context = ssl._create_stdlib_context() + self.sock = context.wrap_socket(self.sock, + server_hostname=self._host) + self.file = None + # RFC 3207: + # The client MUST discard any knowledge obtained from + # the server, such as the list of SMTP service extensions, + # which was not obtained from the TLS negotiation itself. + self.helo_resp = None + self.ehlo_resp = None + self.esmtp_features = {} + self.does_esmtp = False + else: + # RFC 3207: + # 501 Syntax error (no parameters allowed) + # 454 TLS not available due to temporary reason + raise SMTPResponseException(resp, reply) + return (resp, reply) + + def sendmail(self, from_addr, to_addrs, msg, mail_options=(), + rcpt_options=()): + """This command performs an entire mail transaction. + + The arguments are: + - from_addr : The address sending this mail. + - to_addrs : A list of addresses to send this mail to. A bare + string will be treated as a list with 1 address. + - msg : The message to send. + - mail_options : List of ESMTP options (such as 8bitmime) for the + mail command. + - rcpt_options : List of ESMTP options (such as DSN commands) for + all the rcpt commands. + + msg may be a string containing characters in the ASCII range, or a byte + string. A string is encoded to bytes using the ascii codec, and lone + \\r and \\n characters are converted to \\r\\n characters. + + If there has been no previous EHLO or HELO command this session, this + method tries ESMTP EHLO first. If the server does ESMTP, message size + and each of the specified options will be passed to it. If EHLO + fails, HELO will be tried and ESMTP options suppressed. 
+ + This method will return normally if the mail is accepted for at least + one recipient. It returns a dictionary, with one entry for each + recipient that was refused. Each entry contains a tuple of the SMTP + error code and the accompanying error message sent by the server. + + This method may raise the following exceptions: + + SMTPHeloError The server didn't reply properly to + the helo greeting. + SMTPRecipientsRefused The server rejected ALL recipients + (no mail was sent). + SMTPSenderRefused The server didn't accept the from_addr. + SMTPDataError The server replied with an unexpected + error code (other than a refusal of + a recipient). + SMTPNotSupportedError The mail_options parameter includes 'SMTPUTF8' + but the SMTPUTF8 extension is not supported by + the server. + + Note: the connection will be open even after an exception is raised. + + Example: + + >>> import smtplib + >>> s=smtplib.SMTP("localhost") + >>> tolist=["one@one.org","two@two.org","three@three.org","four@four.org"] + >>> msg = '''\\ + ... From: Me@my.org + ... Subject: testin'... + ... + ... This is a test ''' + >>> s.sendmail("me@my.org",tolist,msg) + { "three@three.org" : ( 550 ,"User unknown" ) } + >>> s.quit() + + In the above example, the message was accepted for delivery to three + of the four addresses, and one was rejected, with the error code + 550. If all addresses are accepted, then the method will return an + empty dictionary. 
+ + """ + self.ehlo_or_helo_if_needed() + esmtp_opts = [] + if isinstance(msg, str): + msg = _fix_eols(msg).encode('ascii') + if self.does_esmtp: + if self.has_extn('size'): + esmtp_opts.append("size=%d" % len(msg)) + for option in mail_options: + esmtp_opts.append(option) + (code, resp) = self.mail(from_addr, esmtp_opts) + if code != 250: + if code == 421: + self.close() + else: + self._rset() + raise SMTPSenderRefused(code, resp, from_addr) + senderrs = {} + if isinstance(to_addrs, str): + to_addrs = [to_addrs] + for each in to_addrs: + (code, resp) = self.rcpt(each, rcpt_options) + if (code != 250) and (code != 251): + senderrs[each] = (code, resp) + if code == 421: + self.close() + raise SMTPRecipientsRefused(senderrs) + if len(senderrs) == len(to_addrs): + # the server refused all our recipients + self._rset() + raise SMTPRecipientsRefused(senderrs) + (code, resp) = self.data(msg) + if code != 250: + if code == 421: + self.close() + else: + self._rset() + raise SMTPDataError(code, resp) + #if we got here then somebody got our mail + return senderrs + + def send_message(self, msg, from_addr=None, to_addrs=None, + mail_options=(), rcpt_options=()): + """Converts message to a bytestring and passes it to sendmail. + + The arguments are as for sendmail, except that msg is an + email.message.Message object. If from_addr is None or to_addrs is + None, these arguments are taken from the headers of the Message as + described in RFC 2822 (a ValueError is raised if there is more than + one set of 'Resent-' headers). Regardless of the values of from_addr and + to_addr, any Bcc field (or Resent-Bcc field, when the Message is a + resent) of the Message object won't be transmitted. The Message + object is then serialized using email.generator.BytesGenerator and + sendmail is called to transmit the message. 
If the sender or any of + the recipient addresses contain non-ASCII and the server advertises the + SMTPUTF8 capability, the policy is cloned with utf8 set to True for the + serialization, and SMTPUTF8 and BODY=8BITMIME are asserted on the send. + If the server does not support SMTPUTF8, an SMTPNotSupported error is + raised. Otherwise the generator is called without modifying the + policy. + + """ + # 'Resent-Date' is a mandatory field if the Message is resent (RFC 2822 + # Section 3.6.6). In such a case, we use the 'Resent-*' fields. However, + # if there is more than one 'Resent-' block there's no way to + # unambiguously determine which one is the most recent in all cases, + # so rather than guess we raise a ValueError in that case. + # + # TODO implement heuristics to guess the correct Resent-* block with an + # option allowing the user to enable the heuristics. (It should be + # possible to guess correctly almost all of the time.) + + self.ehlo_or_helo_if_needed() + resent = msg.get_all('Resent-Date') + if resent is None: + header_prefix = '' + elif len(resent) == 1: + header_prefix = 'Resent-' + else: + raise ValueError("message has more than one 'Resent-' header block") + if from_addr is None: + # Prefer the sender field per RFC 2822:3.6.2. + from_addr = (msg[header_prefix + 'Sender'] + if (header_prefix + 'Sender') in msg + else msg[header_prefix + 'From']) + from_addr = email.utils.getaddresses([from_addr])[0][1] + if to_addrs is None: + addr_fields = [f for f in (msg[header_prefix + 'To'], + msg[header_prefix + 'Bcc'], + msg[header_prefix + 'Cc']) + if f is not None] + to_addrs = [a[1] for a in email.utils.getaddresses(addr_fields)] + # Make a local copy so we can delete the bcc headers. 
+ msg_copy = copy.copy(msg) + del msg_copy['Bcc'] + del msg_copy['Resent-Bcc'] + international = False + try: + ''.join([from_addr, *to_addrs]).encode('ascii') + except UnicodeEncodeError: + if not self.has_extn('smtputf8'): + raise SMTPNotSupportedError( + "One or more source or delivery addresses require" + " internationalized email support, but the server" + " does not advertise the required SMTPUTF8 capability") + international = True + with io.BytesIO() as bytesmsg: + if international: + g = email.generator.BytesGenerator( + bytesmsg, policy=msg.policy.clone(utf8=True)) + mail_options = (*mail_options, 'SMTPUTF8', 'BODY=8BITMIME') + else: + g = email.generator.BytesGenerator(bytesmsg) + g.flatten(msg_copy, linesep='\r\n') + flatmsg = bytesmsg.getvalue() + return self.sendmail(from_addr, to_addrs, flatmsg, mail_options, + rcpt_options) + + def close(self): + """Close the connection to the SMTP server.""" + try: + file = self.file + self.file = None + if file: + file.close() + finally: + sock = self.sock + self.sock = None + if sock: + sock.close() + + def quit(self): + """Terminate the SMTP session.""" + res = self.docmd("quit") + # A new EHLO is required after reconnecting with connect() + self.ehlo_resp = self.helo_resp = None + self.esmtp_features = {} + self.does_esmtp = False + self.close() + return res + +if _have_ssl: + + class SMTP_SSL(SMTP): + """ This is a subclass derived from SMTP that connects over an SSL + encrypted socket (to use this class you need a socket module that was + compiled with SSL support). If host is not specified, '' (the local + host) is used. If port is omitted, the standard SMTP-over-SSL port + (465) is used. local_hostname and source_address have the same meaning + as they do in the SMTP class. context also optional, can contain a + SSLContext. 
+ + """ + + default_port = SMTP_SSL_PORT + + def __init__(self, host='', port=0, local_hostname=None, + *, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, + source_address=None, context=None): + if context is None: + context = ssl._create_stdlib_context() + self.context = context + SMTP.__init__(self, host, port, local_hostname, timeout, + source_address) + + def _get_socket(self, host, port, timeout): + if self.debuglevel > 0: + self._print_debug('connect:', (host, port)) + new_socket = super()._get_socket(host, port, timeout) + new_socket = self.context.wrap_socket(new_socket, + server_hostname=self._host) + return new_socket + + __all__.append("SMTP_SSL") + +# +# LMTP extension +# +LMTP_PORT = 2003 + +class LMTP(SMTP): + """LMTP - Local Mail Transfer Protocol + + The LMTP protocol, which is very similar to ESMTP, is heavily based + on the standard SMTP client. It's common to use Unix sockets for + LMTP, so our connect() method must support that as well as a regular + host:port server. local_hostname and source_address have the same + meaning as they do in the SMTP class. To specify a Unix socket, + you must use an absolute path as the host, starting with a '/'. + + Authentication is supported, using the regular SMTP mechanism. 
When + using a Unix socket, LMTP generally don't support or require any + authentication, but your mileage might vary.""" + + ehlo_msg = "lhlo" + + def __init__(self, host='', port=LMTP_PORT, local_hostname=None, + source_address=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT): + """Initialize a new instance.""" + super().__init__(host, port, local_hostname=local_hostname, + source_address=source_address, timeout=timeout) + + def connect(self, host='localhost', port=0, source_address=None): + """Connect to the LMTP daemon, on either a Unix or a TCP socket.""" + if host[0] != '/': + return super().connect(host, port, source_address=source_address) + + if self.timeout is not None and not self.timeout: + raise ValueError('Non-blocking socket (timeout=0) is not supported') + + # Handle Unix-domain sockets. + try: + self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + if self.timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: + self.sock.settimeout(self.timeout) + self.file = None + self.sock.connect(host) + except OSError: + if self.debuglevel > 0: + self._print_debug('connect fail:', host) + if self.sock: + self.sock.close() + self.sock = None + raise + (code, msg) = self.getreply() + if self.debuglevel > 0: + self._print_debug('connect:', msg) + return (code, msg) + + +# Test the sendmail method, which tests most of the others. +# Note: This always sends to localhost. 
+if __name__ == '__main__': + def prompt(prompt): + sys.stdout.write(prompt + ": ") + sys.stdout.flush() + return sys.stdin.readline().strip() + + fromaddr = prompt("From") + toaddrs = prompt("To").split(',') + print("Enter message, end with ^D:") + msg = '' + while line := sys.stdin.readline(): + msg = msg + line + print("Message length is %d" % len(msg)) + + server = SMTP('localhost') + server.set_debuglevel(1) + server.sendmail(fromaddr, toaddrs, msg) + server.quit() diff --git a/Lib/test/ieee754.txt b/Lib/test/ieee754.txt new file mode 100644 index 0000000000..3e986cdb10 --- /dev/null +++ b/Lib/test/ieee754.txt @@ -0,0 +1,183 @@ +====================================== +Python IEEE 754 floating point support +====================================== + +>>> from sys import float_info as FI +>>> from math import * +>>> PI = pi +>>> E = e + +You must never compare two floats with == because you are not going to get +what you expect. We treat two floats as equal if the difference between them +is small than epsilon. +>>> EPS = 1E-15 +>>> def equal(x, y): +... """Almost equal helper for floats""" +... return abs(x - y) < EPS + + +NaNs and INFs +============= + +In Python 2.6 and newer NaNs (not a number) and infinity can be constructed +from the strings 'inf' and 'nan'. + +>>> INF = float('inf') +>>> NINF = float('-inf') +>>> NAN = float('nan') + +>>> INF +inf +>>> NINF +-inf +>>> NAN +nan + +The math module's ``isnan`` and ``isinf`` functions can be used to detect INF +and NAN: +>>> isinf(INF), isinf(NINF), isnan(NAN) +(True, True, True) +>>> INF == -NINF +True + +Infinity +-------- + +Ambiguous operations like ``0 * inf`` or ``inf - inf`` result in NaN. 
+>>> INF * 0 +nan +>>> INF - INF +nan +>>> INF / INF +nan + +However unambiguous operations with inf return inf: +>>> INF * INF +inf +>>> 1.5 * INF +inf +>>> 0.5 * INF +inf +>>> INF / 1000 +inf + +Not a Number +------------ + +NaNs are never equal to another number, even itself +>>> NAN == NAN +False +>>> NAN < 0 +False +>>> NAN >= 0 +False + +All operations involving a NaN return a NaN except for nan**0 and 1**nan. +>>> 1 + NAN +nan +>>> 1 * NAN +nan +>>> 0 * NAN +nan +>>> 1 ** NAN +1.0 +>>> NAN ** 0 +1.0 +>>> 0 ** NAN +nan +>>> (1.0 + FI.epsilon) * NAN +nan + +Misc Functions +============== + +The power of 1 raised to x is always 1.0, even for special values like 0, +infinity and NaN. + +>>> pow(1, 0) +1.0 +>>> pow(1, INF) +1.0 +>>> pow(1, -INF) +1.0 +>>> pow(1, NAN) +1.0 + +The power of 0 raised to x is defined as 0, if x is positive. Negative +finite values are a domain error or zero division error and NaN result in a +silent NaN. + +>>> pow(0, 0) +1.0 +>>> pow(0, INF) +0.0 +>>> pow(0, -INF) +inf +>>> 0 ** -1 +Traceback (most recent call last): +... +ZeroDivisionError: 0.0 cannot be raised to a negative power +>>> pow(0, NAN) +nan + + +Trigonometric Functions +======================= + +>>> sin(INF) +Traceback (most recent call last): +... +ValueError: math domain error +>>> sin(NINF) +Traceback (most recent call last): +... +ValueError: math domain error +>>> sin(NAN) +nan +>>> cos(INF) +Traceback (most recent call last): +... +ValueError: math domain error +>>> cos(NINF) +Traceback (most recent call last): +... +ValueError: math domain error +>>> cos(NAN) +nan +>>> tan(INF) +Traceback (most recent call last): +... +ValueError: math domain error +>>> tan(NINF) +Traceback (most recent call last): +... 
+ValueError: math domain error +>>> tan(NAN) +nan + +Neither pi nor tan are exact, but you can assume that tan(pi/2) is a large value +and tan(pi) is a very small value: +>>> tan(PI/2) > 1E10 +True +>>> -tan(-PI/2) > 1E10 +True +>>> tan(PI) < 1E-15 +True + +>>> asin(NAN), acos(NAN), atan(NAN) +(nan, nan, nan) +>>> asin(INF), asin(NINF) +Traceback (most recent call last): +... +ValueError: math domain error +>>> acos(INF), acos(NINF) +Traceback (most recent call last): +... +ValueError: math domain error +>>> equal(atan(INF), PI/2), equal(atan(NINF), -PI/2) +(True, True) + + +Hyberbolic Functions +==================== + diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py index a1a6bd8e73..1efe5bddb1 100644 --- a/Lib/test/support/__init__.py +++ b/Lib/test/support/__init__.py @@ -400,33 +400,37 @@ def skip_if_buildbot(reason=None): isbuildbot = False return unittest.skipIf(isbuildbot, reason) -def check_sanitizer(*, address=False, memory=False, ub=False): +def check_sanitizer(*, address=False, memory=False, ub=False, thread=False): """Returns True if Python is compiled with sanitizer support""" - if not (address or memory or ub): - raise ValueError('At least one of address, memory, or ub must be True') + if not (address or memory or ub or thread): + raise ValueError('At least one of address, memory, ub or thread must be True') - _cflags = sysconfig.get_config_var('CFLAGS') or '' - _config_args = sysconfig.get_config_var('CONFIG_ARGS') or '' + cflags = sysconfig.get_config_var('CFLAGS') or '' + config_args = sysconfig.get_config_var('CONFIG_ARGS') or '' memory_sanitizer = ( - '-fsanitize=memory' in _cflags or - '--with-memory-sanitizer' in _config_args + '-fsanitize=memory' in cflags or + '--with-memory-sanitizer' in config_args ) address_sanitizer = ( - '-fsanitize=address' in _cflags or - '--with-address-sanitizer' in _config_args + '-fsanitize=address' in cflags or + '--with-address-sanitizer' in config_args ) ub_sanitizer = ( - 
'-fsanitize=undefined' in _cflags or - '--with-undefined-behavior-sanitizer' in _config_args + '-fsanitize=undefined' in cflags or + '--with-undefined-behavior-sanitizer' in config_args + ) + thread_sanitizer = ( + '-fsanitize=thread' in cflags or + '--with-thread-sanitizer' in config_args ) return ( - (memory and memory_sanitizer) or - (address and address_sanitizer) or - (ub and ub_sanitizer) + (memory and memory_sanitizer) or + (address and address_sanitizer) or + (ub and ub_sanitizer) or + (thread and thread_sanitizer) ) - def skip_if_sanitizer(reason=None, *, address=False, memory=False, ub=False, thread=False): """Decorator raising SkipTest if running with a sanitizer active.""" if not reason: @@ -2550,3 +2554,21 @@ def adjust_int_max_str_digits(max_digits): #Windows doesn't have os.uname() but it doesn't support s390x. skip_on_s390x = unittest.skipIf(hasattr(os, 'uname') and os.uname().machine == 's390x', 'skipped on s390x') +HAVE_ASAN_FORK_BUG = check_sanitizer(address=True) + +# From python 3.12.8 +class BrokenIter: + def __init__(self, init_raises=False, next_raises=False, iter_raises=False): + if init_raises: + 1/0 + self.next_raises = next_raises + self.iter_raises = iter_raises + + def __next__(self): + if self.next_raises: + 1/0 + + def __iter__(self): + if self.iter_raises: + 1/0 + return self diff --git a/Lib/test/support/smtpd.py b/Lib/test/support/smtpd.py new file mode 100644 index 0000000000..ec4e7d2f4c --- /dev/null +++ b/Lib/test/support/smtpd.py @@ -0,0 +1,873 @@ +#! /usr/bin/env python3 +"""An RFC 5321 smtp proxy with optional RFC 1870 and RFC 6531 extensions. + +Usage: %(program)s [options] [localhost:localport [remotehost:remoteport]] + +Options: + + --nosetuid + -n + This program generally tries to setuid `nobody', unless this flag is + set. The setuid call will fail if this program is not run as root (in + which case, use this flag). + + --version + -V + Print the version number and exit. 
+ + --class classname + -c classname + Use `classname' as the concrete SMTP proxy class. Uses `PureProxy' by + default. + + --size limit + -s limit + Restrict the total size of the incoming message to "limit" number of + bytes via the RFC 1870 SIZE extension. Defaults to 33554432 bytes. + + --smtputf8 + -u + Enable the SMTPUTF8 extension and behave as an RFC 6531 smtp proxy. + + --debug + -d + Turn on debugging prints. + + --help + -h + Print this message and exit. + +Version: %(__version__)s + +If localhost is not given then `localhost' is used, and if localport is not +given then 8025 is used. If remotehost is not given then `localhost' is used, +and if remoteport is not given, then 25 is used. +""" + +# Overview: +# +# This file implements the minimal SMTP protocol as defined in RFC 5321. It +# has a hierarchy of classes which implement the backend functionality for the +# smtpd. A number of classes are provided: +# +# SMTPServer - the base class for the backend. Raises NotImplementedError +# if you try to use it. +# +# DebuggingServer - simply prints each message it receives on stdout. +# +# PureProxy - Proxies all messages to a real smtpd which does final +# delivery. One known problem with this class is that it doesn't handle +# SMTP errors from the backend server at all. This should be fixed +# (contributions are welcome!). 
+# +# +# Author: Barry Warsaw +# +# TODO: +# +# - support mailbox delivery +# - alias files +# - Handle more ESMTP extensions +# - handle error codes from the backend smtpd + +import sys +import os +import errno +import getopt +import time +import socket +import collections +from test.support import asyncore, asynchat +from warnings import warn +from email._header_value_parser import get_addr_spec, get_angle_addr + +__all__ = [ + "SMTPChannel", "SMTPServer", "DebuggingServer", "PureProxy", +] + +program = sys.argv[0] +__version__ = 'Python SMTP proxy version 0.3' + + +class Devnull: + def write(self, msg): pass + def flush(self): pass + + +DEBUGSTREAM = Devnull() +NEWLINE = '\n' +COMMASPACE = ', ' +DATA_SIZE_DEFAULT = 33554432 + + +def usage(code, msg=''): + print(__doc__ % globals(), file=sys.stderr) + if msg: + print(msg, file=sys.stderr) + sys.exit(code) + + +class SMTPChannel(asynchat.async_chat): + COMMAND = 0 + DATA = 1 + + command_size_limit = 512 + command_size_limits = collections.defaultdict(lambda x=command_size_limit: x) + + @property + def max_command_size_limit(self): + try: + return max(self.command_size_limits.values()) + except ValueError: + return self.command_size_limit + + def __init__(self, server, conn, addr, data_size_limit=DATA_SIZE_DEFAULT, + map=None, enable_SMTPUTF8=False, decode_data=False): + asynchat.async_chat.__init__(self, conn, map=map) + self.smtp_server = server + self.conn = conn + self.addr = addr + self.data_size_limit = data_size_limit + self.enable_SMTPUTF8 = enable_SMTPUTF8 + self._decode_data = decode_data + if enable_SMTPUTF8 and decode_data: + raise ValueError("decode_data and enable_SMTPUTF8 cannot" + " be set to True at the same time") + if decode_data: + self._emptystring = '' + self._linesep = '\r\n' + self._dotsep = '.' 
+ self._newline = NEWLINE + else: + self._emptystring = b'' + self._linesep = b'\r\n' + self._dotsep = ord(b'.') + self._newline = b'\n' + self._set_rset_state() + self.seen_greeting = '' + self.extended_smtp = False + self.command_size_limits.clear() + self.fqdn = socket.getfqdn() + try: + self.peer = conn.getpeername() + except OSError as err: + # a race condition may occur if the other end is closing + # before we can get the peername + self.close() + if err.errno != errno.ENOTCONN: + raise + return + print('Peer:', repr(self.peer), file=DEBUGSTREAM) + self.push('220 %s %s' % (self.fqdn, __version__)) + + def _set_post_data_state(self): + """Reset state variables to their post-DATA state.""" + self.smtp_state = self.COMMAND + self.mailfrom = None + self.rcpttos = [] + self.require_SMTPUTF8 = False + self.num_bytes = 0 + self.set_terminator(b'\r\n') + + def _set_rset_state(self): + """Reset all state variables except the greeting.""" + self._set_post_data_state() + self.received_data = '' + self.received_lines = [] + + + # properties for backwards-compatibility + @property + def __server(self): + warn("Access to __server attribute on SMTPChannel is deprecated, " + "use 'smtp_server' instead", DeprecationWarning, 2) + return self.smtp_server + @__server.setter + def __server(self, value): + warn("Setting __server attribute on SMTPChannel is deprecated, " + "set 'smtp_server' instead", DeprecationWarning, 2) + self.smtp_server = value + + @property + def __line(self): + warn("Access to __line attribute on SMTPChannel is deprecated, " + "use 'received_lines' instead", DeprecationWarning, 2) + return self.received_lines + @__line.setter + def __line(self, value): + warn("Setting __line attribute on SMTPChannel is deprecated, " + "set 'received_lines' instead", DeprecationWarning, 2) + self.received_lines = value + + @property + def __state(self): + warn("Access to __state attribute on SMTPChannel is deprecated, " + "use 'smtp_state' instead", DeprecationWarning, 2) + 
return self.smtp_state + @__state.setter + def __state(self, value): + warn("Setting __state attribute on SMTPChannel is deprecated, " + "set 'smtp_state' instead", DeprecationWarning, 2) + self.smtp_state = value + + @property + def __greeting(self): + warn("Access to __greeting attribute on SMTPChannel is deprecated, " + "use 'seen_greeting' instead", DeprecationWarning, 2) + return self.seen_greeting + @__greeting.setter + def __greeting(self, value): + warn("Setting __greeting attribute on SMTPChannel is deprecated, " + "set 'seen_greeting' instead", DeprecationWarning, 2) + self.seen_greeting = value + + @property + def __mailfrom(self): + warn("Access to __mailfrom attribute on SMTPChannel is deprecated, " + "use 'mailfrom' instead", DeprecationWarning, 2) + return self.mailfrom + @__mailfrom.setter + def __mailfrom(self, value): + warn("Setting __mailfrom attribute on SMTPChannel is deprecated, " + "set 'mailfrom' instead", DeprecationWarning, 2) + self.mailfrom = value + + @property + def __rcpttos(self): + warn("Access to __rcpttos attribute on SMTPChannel is deprecated, " + "use 'rcpttos' instead", DeprecationWarning, 2) + return self.rcpttos + @__rcpttos.setter + def __rcpttos(self, value): + warn("Setting __rcpttos attribute on SMTPChannel is deprecated, " + "set 'rcpttos' instead", DeprecationWarning, 2) + self.rcpttos = value + + @property + def __data(self): + warn("Access to __data attribute on SMTPChannel is deprecated, " + "use 'received_data' instead", DeprecationWarning, 2) + return self.received_data + @__data.setter + def __data(self, value): + warn("Setting __data attribute on SMTPChannel is deprecated, " + "set 'received_data' instead", DeprecationWarning, 2) + self.received_data = value + + @property + def __fqdn(self): + warn("Access to __fqdn attribute on SMTPChannel is deprecated, " + "use 'fqdn' instead", DeprecationWarning, 2) + return self.fqdn + @__fqdn.setter + def __fqdn(self, value): + warn("Setting __fqdn attribute on SMTPChannel 
is deprecated, " + "set 'fqdn' instead", DeprecationWarning, 2) + self.fqdn = value + + @property + def __peer(self): + warn("Access to __peer attribute on SMTPChannel is deprecated, " + "use 'peer' instead", DeprecationWarning, 2) + return self.peer + @__peer.setter + def __peer(self, value): + warn("Setting __peer attribute on SMTPChannel is deprecated, " + "set 'peer' instead", DeprecationWarning, 2) + self.peer = value + + @property + def __conn(self): + warn("Access to __conn attribute on SMTPChannel is deprecated, " + "use 'conn' instead", DeprecationWarning, 2) + return self.conn + @__conn.setter + def __conn(self, value): + warn("Setting __conn attribute on SMTPChannel is deprecated, " + "set 'conn' instead", DeprecationWarning, 2) + self.conn = value + + @property + def __addr(self): + warn("Access to __addr attribute on SMTPChannel is deprecated, " + "use 'addr' instead", DeprecationWarning, 2) + return self.addr + @__addr.setter + def __addr(self, value): + warn("Setting __addr attribute on SMTPChannel is deprecated, " + "set 'addr' instead", DeprecationWarning, 2) + self.addr = value + + # Overrides base class for convenience. 
+ def push(self, msg): + asynchat.async_chat.push(self, bytes( + msg + '\r\n', 'utf-8' if self.require_SMTPUTF8 else 'ascii')) + + # Implementation of base class abstract method + def collect_incoming_data(self, data): + limit = None + if self.smtp_state == self.COMMAND: + limit = self.max_command_size_limit + elif self.smtp_state == self.DATA: + limit = self.data_size_limit + if limit and self.num_bytes > limit: + return + elif limit: + self.num_bytes += len(data) + if self._decode_data: + self.received_lines.append(str(data, 'utf-8')) + else: + self.received_lines.append(data) + + # Implementation of base class abstract method + def found_terminator(self): + line = self._emptystring.join(self.received_lines) + print('Data:', repr(line), file=DEBUGSTREAM) + self.received_lines = [] + if self.smtp_state == self.COMMAND: + sz, self.num_bytes = self.num_bytes, 0 + if not line: + self.push('500 Error: bad syntax') + return + if not self._decode_data: + line = str(line, 'utf-8') + i = line.find(' ') + if i < 0: + command = line.upper() + arg = None + else: + command = line[:i].upper() + arg = line[i+1:].strip() + max_sz = (self.command_size_limits[command] + if self.extended_smtp else self.command_size_limit) + if sz > max_sz: + self.push('500 Error: line too long') + return + method = getattr(self, 'smtp_' + command, None) + if not method: + self.push('500 Error: command "%s" not recognized' % command) + return + method(arg) + return + else: + if self.smtp_state != self.DATA: + self.push('451 Internal confusion') + self.num_bytes = 0 + return + if self.data_size_limit and self.num_bytes > self.data_size_limit: + self.push('552 Error: Too much mail data') + self.num_bytes = 0 + return + # Remove extraneous carriage returns and de-transparency according + # to RFC 5321, Section 4.5.2. 
+ data = [] + for text in line.split(self._linesep): + if text and text[0] == self._dotsep: + data.append(text[1:]) + else: + data.append(text) + self.received_data = self._newline.join(data) + args = (self.peer, self.mailfrom, self.rcpttos, self.received_data) + kwargs = {} + if not self._decode_data: + kwargs = { + 'mail_options': self.mail_options, + 'rcpt_options': self.rcpt_options, + } + status = self.smtp_server.process_message(*args, **kwargs) + self._set_post_data_state() + if not status: + self.push('250 OK') + else: + self.push(status) + + # SMTP and ESMTP commands + def smtp_HELO(self, arg): + if not arg: + self.push('501 Syntax: HELO hostname') + return + # See issue #21783 for a discussion of this behavior. + if self.seen_greeting: + self.push('503 Duplicate HELO/EHLO') + return + self._set_rset_state() + self.seen_greeting = arg + self.push('250 %s' % self.fqdn) + + def smtp_EHLO(self, arg): + if not arg: + self.push('501 Syntax: EHLO hostname') + return + # See issue #21783 for a discussion of this behavior. 
+ if self.seen_greeting: + self.push('503 Duplicate HELO/EHLO') + return + self._set_rset_state() + self.seen_greeting = arg + self.extended_smtp = True + self.push('250-%s' % self.fqdn) + if self.data_size_limit: + self.push('250-SIZE %s' % self.data_size_limit) + self.command_size_limits['MAIL'] += 26 + if not self._decode_data: + self.push('250-8BITMIME') + if self.enable_SMTPUTF8: + self.push('250-SMTPUTF8') + self.command_size_limits['MAIL'] += 10 + self.push('250 HELP') + + def smtp_NOOP(self, arg): + if arg: + self.push('501 Syntax: NOOP') + else: + self.push('250 OK') + + def smtp_QUIT(self, arg): + # args is ignored + self.push('221 Bye') + self.close_when_done() + + def _strip_command_keyword(self, keyword, arg): + keylen = len(keyword) + if arg[:keylen].upper() == keyword: + return arg[keylen:].strip() + return '' + + def _getaddr(self, arg): + if not arg: + return '', '' + if arg.lstrip().startswith('<'): + address, rest = get_angle_addr(arg) + else: + address, rest = get_addr_spec(arg) + if not address: + return address, rest + return address.addr_spec, rest + + def _getparams(self, params): + # Return params as dictionary. Return None if not all parameters + # appear to be syntactically valid according to RFC 1869. + result = {} + for param in params: + param, eq, value = param.partition('=') + if not param.isalnum() or eq and not value: + return None + result[param] = value if eq else True + return result + + def smtp_HELP(self, arg): + if arg: + extended = ' [SP ]' + lc_arg = arg.upper() + if lc_arg == 'EHLO': + self.push('250 Syntax: EHLO hostname') + elif lc_arg == 'HELO': + self.push('250 Syntax: HELO hostname') + elif lc_arg == 'MAIL': + msg = '250 Syntax: MAIL FROM:
' + if self.extended_smtp: + msg += extended + self.push(msg) + elif lc_arg == 'RCPT': + msg = '250 Syntax: RCPT TO:
' + if self.extended_smtp: + msg += extended + self.push(msg) + elif lc_arg == 'DATA': + self.push('250 Syntax: DATA') + elif lc_arg == 'RSET': + self.push('250 Syntax: RSET') + elif lc_arg == 'NOOP': + self.push('250 Syntax: NOOP') + elif lc_arg == 'QUIT': + self.push('250 Syntax: QUIT') + elif lc_arg == 'VRFY': + self.push('250 Syntax: VRFY
') + else: + self.push('501 Supported commands: EHLO HELO MAIL RCPT ' + 'DATA RSET NOOP QUIT VRFY') + else: + self.push('250 Supported commands: EHLO HELO MAIL RCPT DATA ' + 'RSET NOOP QUIT VRFY') + + def smtp_VRFY(self, arg): + if arg: + address, params = self._getaddr(arg) + if address: + self.push('252 Cannot VRFY user, but will accept message ' + 'and attempt delivery') + else: + self.push('502 Could not VRFY %s' % arg) + else: + self.push('501 Syntax: VRFY
') + + def smtp_MAIL(self, arg): + if not self.seen_greeting: + self.push('503 Error: send HELO first') + return + print('===> MAIL', arg, file=DEBUGSTREAM) + syntaxerr = '501 Syntax: MAIL FROM:
' + if self.extended_smtp: + syntaxerr += ' [SP ]' + if arg is None: + self.push(syntaxerr) + return + arg = self._strip_command_keyword('FROM:', arg) + address, params = self._getaddr(arg) + if not address: + self.push(syntaxerr) + return + if not self.extended_smtp and params: + self.push(syntaxerr) + return + if self.mailfrom: + self.push('503 Error: nested MAIL command') + return + self.mail_options = params.upper().split() + params = self._getparams(self.mail_options) + if params is None: + self.push(syntaxerr) + return + if not self._decode_data: + body = params.pop('BODY', '7BIT') + if body not in ['7BIT', '8BITMIME']: + self.push('501 Error: BODY can only be one of 7BIT, 8BITMIME') + return + if self.enable_SMTPUTF8: + smtputf8 = params.pop('SMTPUTF8', False) + if smtputf8 is True: + self.require_SMTPUTF8 = True + elif smtputf8 is not False: + self.push('501 Error: SMTPUTF8 takes no arguments') + return + size = params.pop('SIZE', None) + if size: + if not size.isdigit(): + self.push(syntaxerr) + return + elif self.data_size_limit and int(size) > self.data_size_limit: + self.push('552 Error: message size exceeds fixed maximum message size') + return + if len(params.keys()) > 0: + self.push('555 MAIL FROM parameters not recognized or not implemented') + return + self.mailfrom = address + print('sender:', self.mailfrom, file=DEBUGSTREAM) + self.push('250 OK') + + def smtp_RCPT(self, arg): + if not self.seen_greeting: + self.push('503 Error: send HELO first'); + return + print('===> RCPT', arg, file=DEBUGSTREAM) + if not self.mailfrom: + self.push('503 Error: need MAIL command') + return + syntaxerr = '501 Syntax: RCPT TO:
' + if self.extended_smtp: + syntaxerr += ' [SP ]' + if arg is None: + self.push(syntaxerr) + return + arg = self._strip_command_keyword('TO:', arg) + address, params = self._getaddr(arg) + if not address: + self.push(syntaxerr) + return + if not self.extended_smtp and params: + self.push(syntaxerr) + return + self.rcpt_options = params.upper().split() + params = self._getparams(self.rcpt_options) + if params is None: + self.push(syntaxerr) + return + # XXX currently there are no options we recognize. + if len(params.keys()) > 0: + self.push('555 RCPT TO parameters not recognized or not implemented') + return + self.rcpttos.append(address) + print('recips:', self.rcpttos, file=DEBUGSTREAM) + self.push('250 OK') + + def smtp_RSET(self, arg): + if arg: + self.push('501 Syntax: RSET') + return + self._set_rset_state() + self.push('250 OK') + + def smtp_DATA(self, arg): + if not self.seen_greeting: + self.push('503 Error: send HELO first'); + return + if not self.rcpttos: + self.push('503 Error: need RCPT command') + return + if arg: + self.push('501 Syntax: DATA') + return + self.smtp_state = self.DATA + self.set_terminator(b'\r\n.\r\n') + self.push('354 End data with .') + + # Commands that have not been implemented + def smtp_EXPN(self, arg): + self.push('502 EXPN not implemented') + + +class SMTPServer(asyncore.dispatcher): + # SMTPChannel class to use for managing client connections + channel_class = SMTPChannel + + def __init__(self, localaddr, remoteaddr, + data_size_limit=DATA_SIZE_DEFAULT, map=None, + enable_SMTPUTF8=False, decode_data=False): + self._localaddr = localaddr + self._remoteaddr = remoteaddr + self.data_size_limit = data_size_limit + self.enable_SMTPUTF8 = enable_SMTPUTF8 + self._decode_data = decode_data + if enable_SMTPUTF8 and decode_data: + raise ValueError("decode_data and enable_SMTPUTF8 cannot" + " be set to True at the same time") + asyncore.dispatcher.__init__(self, map=map) + try: + gai_results = socket.getaddrinfo(*localaddr, + 
type=socket.SOCK_STREAM) + self.create_socket(gai_results[0][0], gai_results[0][1]) + # try to re-use a server port if possible + self.set_reuse_addr() + self.bind(localaddr) + self.listen(5) + except: + self.close() + raise + else: + print('%s started at %s\n\tLocal addr: %s\n\tRemote addr:%s' % ( + self.__class__.__name__, time.ctime(time.time()), + localaddr, remoteaddr), file=DEBUGSTREAM) + + def handle_accepted(self, conn, addr): + print('Incoming connection from %s' % repr(addr), file=DEBUGSTREAM) + channel = self.channel_class(self, + conn, + addr, + self.data_size_limit, + self._map, + self.enable_SMTPUTF8, + self._decode_data) + + # API for "doing something useful with the message" + def process_message(self, peer, mailfrom, rcpttos, data, **kwargs): + """Override this abstract method to handle messages from the client. + + peer is a tuple containing (ipaddr, port) of the client that made the + socket connection to our smtp port. + + mailfrom is the raw address the client claims the message is coming + from. + + rcpttos is a list of raw addresses the client wishes to deliver the + message to. + + data is a string containing the entire full text of the message, + headers (if supplied) and all. It has been `de-transparencied' + according to RFC 821, Section 4.5.2. In other words, a line + containing a `.' followed by other text has had the leading dot + removed. + + kwargs is a dictionary containing additional information. It is + empty if decode_data=True was given as init parameter, otherwise + it will contain the following keys: + 'mail_options': list of parameters to the mail command. All + elements are uppercase strings. Example: + ['BODY=8BITMIME', 'SMTPUTF8']. + 'rcpt_options': same, for the rcpt command. + + This function should return None for a normal `250 Ok' response; + otherwise, it should return the desired response string in RFC 821 + format. 
+ + """ + raise NotImplementedError + + +class DebuggingServer(SMTPServer): + + def _print_message_content(self, peer, data): + inheaders = 1 + lines = data.splitlines() + for line in lines: + # headers first + if inheaders and not line: + peerheader = 'X-Peer: ' + peer[0] + if not isinstance(data, str): + # decoded_data=false; make header match other binary output + peerheader = repr(peerheader.encode('utf-8')) + print(peerheader) + inheaders = 0 + if not isinstance(data, str): + # Avoid spurious 'str on bytes instance' warning. + line = repr(line) + print(line) + + def process_message(self, peer, mailfrom, rcpttos, data, **kwargs): + print('---------- MESSAGE FOLLOWS ----------') + if kwargs: + if kwargs.get('mail_options'): + print('mail options: %s' % kwargs['mail_options']) + if kwargs.get('rcpt_options'): + print('rcpt options: %s\n' % kwargs['rcpt_options']) + self._print_message_content(peer, data) + print('------------ END MESSAGE ------------') + + +class PureProxy(SMTPServer): + def __init__(self, *args, **kwargs): + if 'enable_SMTPUTF8' in kwargs and kwargs['enable_SMTPUTF8']: + raise ValueError("PureProxy does not support SMTPUTF8.") + super(PureProxy, self).__init__(*args, **kwargs) + + def process_message(self, peer, mailfrom, rcpttos, data): + lines = data.split('\n') + # Look for the last header + i = 0 + for line in lines: + if not line: + break + i += 1 + lines.insert(i, 'X-Peer: %s' % peer[0]) + data = NEWLINE.join(lines) + refused = self._deliver(mailfrom, rcpttos, data) + # TBD: what to do with refused addresses? 
+ print('we got some refusals:', refused, file=DEBUGSTREAM) + + def _deliver(self, mailfrom, rcpttos, data): + import smtplib + refused = {} + try: + s = smtplib.SMTP() + s.connect(self._remoteaddr[0], self._remoteaddr[1]) + try: + refused = s.sendmail(mailfrom, rcpttos, data) + finally: + s.quit() + except smtplib.SMTPRecipientsRefused as e: + print('got SMTPRecipientsRefused', file=DEBUGSTREAM) + refused = e.recipients + except (OSError, smtplib.SMTPException) as e: + print('got', e.__class__, file=DEBUGSTREAM) + # All recipients were refused. If the exception had an associated + # error code, use it. Otherwise,fake it with a non-triggering + # exception code. + errcode = getattr(e, 'smtp_code', -1) + errmsg = getattr(e, 'smtp_error', 'ignore') + for r in rcpttos: + refused[r] = (errcode, errmsg) + return refused + + +class Options: + setuid = True + classname = 'PureProxy' + size_limit = None + enable_SMTPUTF8 = False + + +def parseargs(): + global DEBUGSTREAM + try: + opts, args = getopt.getopt( + sys.argv[1:], 'nVhc:s:du', + ['class=', 'nosetuid', 'version', 'help', 'size=', 'debug', + 'smtputf8']) + except getopt.error as e: + usage(1, e) + + options = Options() + for opt, arg in opts: + if opt in ('-h', '--help'): + usage(0) + elif opt in ('-V', '--version'): + print(__version__) + sys.exit(0) + elif opt in ('-n', '--nosetuid'): + options.setuid = False + elif opt in ('-c', '--class'): + options.classname = arg + elif opt in ('-d', '--debug'): + DEBUGSTREAM = sys.stderr + elif opt in ('-u', '--smtputf8'): + options.enable_SMTPUTF8 = True + elif opt in ('-s', '--size'): + try: + int_size = int(arg) + options.size_limit = int_size + except: + print('Invalid size: ' + arg, file=sys.stderr) + sys.exit(1) + + # parse the rest of the arguments + if len(args) < 1: + localspec = 'localhost:8025' + remotespec = 'localhost:25' + elif len(args) < 2: + localspec = args[0] + remotespec = 'localhost:25' + elif len(args) < 3: + localspec = args[0] + remotespec = args[1] + 
else: + usage(1, 'Invalid arguments: %s' % COMMASPACE.join(args)) + + # split into host/port pairs + i = localspec.find(':') + if i < 0: + usage(1, 'Bad local spec: %s' % localspec) + options.localhost = localspec[:i] + try: + options.localport = int(localspec[i+1:]) + except ValueError: + usage(1, 'Bad local port: %s' % localspec) + i = remotespec.find(':') + if i < 0: + usage(1, 'Bad remote spec: %s' % remotespec) + options.remotehost = remotespec[:i] + try: + options.remoteport = int(remotespec[i+1:]) + except ValueError: + usage(1, 'Bad remote port: %s' % remotespec) + return options + + +if __name__ == '__main__': + options = parseargs() + # Become nobody + classname = options.classname + if "." in classname: + lastdot = classname.rfind(".") + mod = __import__(classname[:lastdot], globals(), locals(), [""]) + classname = classname[lastdot+1:] + else: + import __main__ as mod + class_ = getattr(mod, classname) + proxy = class_((options.localhost, options.localport), + (options.remotehost, options.remoteport), + options.size_limit, enable_SMTPUTF8=options.enable_SMTPUTF8) + if options.setuid: + try: + import pwd + except ImportError: + print('Cannot import module "pwd"; try running with -n option.', file=sys.stderr) + sys.exit(1) + nobody = pwd.getpwnam('nobody')[2] + try: + os.setuid(nobody) + except PermissionError: + print('Cannot setuid "nobody"; try running with -n option.', file=sys.stderr) + sys.exit(1) + try: + asyncore.loop() + except KeyboardInterrupt: + pass diff --git a/Lib/test/test__colorize.py b/Lib/test/test__colorize.py new file mode 100644 index 0000000000..056a5306ce --- /dev/null +++ b/Lib/test/test__colorize.py @@ -0,0 +1,135 @@ +import contextlib +import io +import sys +import unittest +import unittest.mock +import _colorize +from test.support.os_helper import EnvironmentVarGuard + + +@contextlib.contextmanager +def clear_env(): + with EnvironmentVarGuard() as mock_env: + for var in "FORCE_COLOR", "NO_COLOR", "PYTHON_COLORS": + 
mock_env.unset(var) + yield mock_env + + +def supports_virtual_terminal(): + if sys.platform == "win32": + return unittest.mock.patch("nt._supports_virtual_terminal", return_value=True) + else: + return contextlib.nullcontext() + + +class TestColorizeFunction(unittest.TestCase): + def test_colorized_detection_checks_for_environment_variables(self): + def check(env, fallback, expected): + with (self.subTest(env=env, fallback=fallback), + clear_env() as mock_env): + mock_env.update(env) + isatty_mock.return_value = fallback + stdout_mock.isatty.return_value = fallback + self.assertEqual(_colorize.can_colorize(), expected) + + with (unittest.mock.patch("os.isatty") as isatty_mock, + unittest.mock.patch("sys.stdout") as stdout_mock, + supports_virtual_terminal()): + stdout_mock.fileno.return_value = 1 + + for fallback in False, True: + check({}, fallback, fallback) + check({'TERM': 'dumb'}, fallback, False) + check({'TERM': 'xterm'}, fallback, fallback) + check({'TERM': ''}, fallback, fallback) + check({'FORCE_COLOR': '1'}, fallback, True) + check({'FORCE_COLOR': '0'}, fallback, True) + check({'FORCE_COLOR': ''}, fallback, fallback) + check({'NO_COLOR': '1'}, fallback, False) + check({'NO_COLOR': '0'}, fallback, False) + check({'NO_COLOR': ''}, fallback, fallback) + + check({'TERM': 'dumb', 'FORCE_COLOR': '1'}, False, True) + check({'FORCE_COLOR': '1', 'NO_COLOR': '1'}, True, False) + + for ignore_environment in False, True: + # Simulate running with or without `-E`. 
+ flags = unittest.mock.MagicMock(ignore_environment=ignore_environment) + with unittest.mock.patch("sys.flags", flags): + check({'PYTHON_COLORS': '1'}, True, True) + check({'PYTHON_COLORS': '1'}, False, not ignore_environment) + check({'PYTHON_COLORS': '0'}, True, ignore_environment) + check({'PYTHON_COLORS': '0'}, False, False) + for fallback in False, True: + check({'PYTHON_COLORS': 'x'}, fallback, fallback) + check({'PYTHON_COLORS': ''}, fallback, fallback) + + check({'TERM': 'dumb', 'PYTHON_COLORS': '1'}, False, not ignore_environment) + check({'NO_COLOR': '1', 'PYTHON_COLORS': '1'}, False, not ignore_environment) + check({'FORCE_COLOR': '1', 'PYTHON_COLORS': '0'}, True, ignore_environment) + + @unittest.skipUnless(sys.platform == "win32", "requires Windows") + def test_colorized_detection_checks_on_windows(self): + with (clear_env(), + unittest.mock.patch("os.isatty") as isatty_mock, + unittest.mock.patch("sys.stdout") as stdout_mock, + supports_virtual_terminal() as vt_mock): + stdout_mock.fileno.return_value = 1 + isatty_mock.return_value = True + stdout_mock.isatty.return_value = True + + vt_mock.return_value = True + self.assertEqual(_colorize.can_colorize(), True) + vt_mock.return_value = False + self.assertEqual(_colorize.can_colorize(), False) + import nt + del nt._supports_virtual_terminal + self.assertEqual(_colorize.can_colorize(), False) + + def test_colorized_detection_checks_for_std_streams(self): + with (clear_env(), + unittest.mock.patch("os.isatty") as isatty_mock, + unittest.mock.patch("sys.stdout") as stdout_mock, + unittest.mock.patch("sys.stderr") as stderr_mock, + supports_virtual_terminal()): + stdout_mock.fileno.return_value = 1 + stderr_mock.fileno.side_effect = ZeroDivisionError + stderr_mock.isatty.side_effect = ZeroDivisionError + + isatty_mock.return_value = True + stdout_mock.isatty.return_value = True + self.assertEqual(_colorize.can_colorize(), True) + + isatty_mock.return_value = False + stdout_mock.isatty.return_value = False 
+ self.assertEqual(_colorize.can_colorize(), False) + + def test_colorized_detection_checks_for_file(self): + with clear_env(), supports_virtual_terminal(): + + with unittest.mock.patch("os.isatty") as isatty_mock: + file = unittest.mock.MagicMock() + file.fileno.return_value = 1 + isatty_mock.return_value = True + self.assertEqual(_colorize.can_colorize(file=file), True) + isatty_mock.return_value = False + self.assertEqual(_colorize.can_colorize(file=file), False) + + # No file.fileno. + with unittest.mock.patch("os.isatty", side_effect=ZeroDivisionError): + file = unittest.mock.MagicMock(spec=['isatty']) + file.isatty.return_value = True + self.assertEqual(_colorize.can_colorize(file=file), False) + + # file.fileno() raises io.UnsupportedOperation. + with unittest.mock.patch("os.isatty", side_effect=ZeroDivisionError): + file = unittest.mock.MagicMock() + file.fileno.side_effect = io.UnsupportedOperation + file.isatty.return_value = True + self.assertEqual(_colorize.can_colorize(file=file), True) + file.isatty.return_value = False + self.assertEqual(_colorize.can_colorize(file=file), False) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_charmapcodec.py b/Lib/test/test_charmapcodec.py index e69f1c6e4b..8ea75d9129 100644 --- a/Lib/test/test_charmapcodec.py +++ b/Lib/test/test_charmapcodec.py @@ -26,7 +26,6 @@ def codec_search_function(encoding): codecname = 'testcodec' class CharmapCodecTest(unittest.TestCase): - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") def test_constructorx(self): self.assertEqual(str(b'abc', codecname), 'abc') self.assertEqual(str(b'xdef', codecname), 'abcdef') @@ -43,14 +42,12 @@ def test_encodex(self): self.assertEqual('dxf'.encode(codecname), b'dabcf') self.assertEqual('dxfx'.encode(codecname), b'dabcfabc') - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") def test_constructory(self): self.assertEqual(str(b'ydef', codecname), 'def') self.assertEqual(str(b'defy', codecname), 'def') 
self.assertEqual(str(b'dyf', codecname), 'df') self.assertEqual(str(b'dyfy', codecname), 'df') - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") def test_maptoundefined(self): self.assertRaises(UnicodeError, str, b'abc\001', codecname) diff --git a/Lib/test/test_codecs.py b/Lib/test/test_codecs.py index 085b800b6d..f29e91e088 100644 --- a/Lib/test/test_codecs.py +++ b/Lib/test/test_codecs.py @@ -1827,7 +1827,6 @@ def test_decode(self): self.assertEqual(codecs.decode(b'[\xff]', 'ascii', errors='ignore'), '[]') - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") def test_encode(self): self.assertEqual(codecs.encode('\xe4\xf6\xfc', 'latin-1'), b'\xe4\xf6\xfc') @@ -1846,7 +1845,6 @@ def test_register(self): self.assertRaises(TypeError, codecs.register) self.assertRaises(TypeError, codecs.register, 42) - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON; AttributeError: module '_winapi' has no attribute 'GetACP'") def test_unregister(self): name = "nonexistent_codec_name" search_function = mock.Mock() @@ -1859,28 +1857,23 @@ def test_unregister(self): self.assertRaises(LookupError, codecs.lookup, name) search_function.assert_not_called() - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") def test_lookup(self): self.assertRaises(TypeError, codecs.lookup) self.assertRaises(LookupError, codecs.lookup, "__spam__") self.assertRaises(LookupError, codecs.lookup, " ") - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") def test_getencoder(self): self.assertRaises(TypeError, codecs.getencoder) self.assertRaises(LookupError, codecs.getencoder, "__spam__") - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") def test_getdecoder(self): self.assertRaises(TypeError, codecs.getdecoder) self.assertRaises(LookupError, codecs.getdecoder, "__spam__") - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") def test_getreader(self): self.assertRaises(TypeError, codecs.getreader) self.assertRaises(LookupError, codecs.getreader, "__spam__") - 
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") def test_getwriter(self): self.assertRaises(TypeError, codecs.getwriter) self.assertRaises(LookupError, codecs.getwriter, "__spam__") @@ -1939,7 +1932,6 @@ def test_undefined(self): self.assertRaises(UnicodeError, codecs.decode, b'abc', 'undefined', errors) - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") def test_file_closes_if_lookup_error_raised(self): mock_open = mock.mock_open() with mock.patch('builtins.open', mock_open) as file: @@ -3287,7 +3279,6 @@ def test_multiple_args(self): self.check_note(RuntimeError('a', 'b', 'c'), msg_re) # http://bugs.python.org/issue19609 - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") def test_codec_lookup_failure(self): msg = "^unknown encoding: {}$".format(self.codec_name) with self.assertRaisesRegex(LookupError, msg): @@ -3523,8 +3514,6 @@ def test_incremental(self): False) self.assertEqual(decoded, ('abc', 3)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_mbcs_alias(self): # Check that looking up our 'default' codepage will return # mbcs when we don't have a more specific one available diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index 3989b7d674..bff85a7ec2 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -1369,10 +1369,12 @@ class Inner(Enum): [Outer.a, Outer.b, Outer.Inner], ) + # TODO: RUSTPYTHON + @unittest.expectedFailure @unittest.skipIf( - python_version < (3, 13), - 'inner classes are still members', - ) + python_version < (3, 13), + 'inner classes are still members', + ) def test_nested_classes_in_enum_are_not_members(self): """Support locally-defined nested classes.""" class Outer(Enum): @@ -4555,20 +4557,24 @@ class Color(Enum): self.assertEqual(Color.green.value, 3) self.assertEqual(Color.yellow.value, 4) + # TODO: RUSTPYTHON + @unittest.expectedFailure @unittest.skipIf( - python_version < (3, 13), - 'mixed types with auto() will raise in 3.13', - ) + python_version < (3, 13), + 'inner classes are 
still members', + ) def test_auto_garbage_fail(self): with self.assertRaisesRegex(TypeError, 'will require all values to be sortable'): class Color(Enum): red = 'red' blue = auto() + # TODO: RUSTPYTHON + @unittest.expectedFailure @unittest.skipIf( - python_version < (3, 13), - 'mixed types with auto() will raise in 3.13', - ) + python_version < (3, 13), + 'inner classes are still members', + ) def test_auto_garbage_corrected_fail(self): with self.assertRaisesRegex(TypeError, 'will require all values to be sortable'): class Color(Enum): @@ -4598,9 +4604,9 @@ def _generate_next_value_(name, start, count, last): self.assertEqual(Color.blue.value, 'blue') @unittest.skipIf( - python_version < (3, 13), - 'auto() will return highest value + 1 in 3.13', - ) + python_version < (3, 13), + 'inner classes are still members', + ) def test_auto_with_aliases(self): class Color(Enum): red = auto() diff --git a/Lib/test/test_float.py b/Lib/test/test_float.py index 353c1ea9b7..19c17af596 100644 --- a/Lib/test/test_float.py +++ b/Lib/test/test_float.py @@ -133,7 +133,7 @@ def check(s): with self.assertRaises(ValueError, msg='float(%r)' % (s,)) as cm: float(s) self.assertEqual(str(cm.exception), - 'could not convert string to float: %r' % (s,)) + 'could not convert string to float: %r' % (s,)) check('\xbd') check('123\xbd') @@ -155,7 +155,9 @@ def check(s): # non-UTF-8 byte string check(b'123\xa0') - @support.run_with_locale('LC_NUMERIC', 'fr_FR', 'de_DE') + # TODO: RUSTPYTHON + @unittest.skip("RustPython panics on this") + @support.run_with_locale('LC_NUMERIC', 'fr_FR', 'de_DE', '') def test_float_with_comma(self): # set locale to something that doesn't use '.' 
for the decimal point # float must not accept the locale specific decimal point but @@ -290,11 +292,11 @@ def test_is_integer(self): def test_floatasratio(self): for f, ratio in [ - (0.875, (7, 8)), - (-0.875, (-7, 8)), - (0.0, (0, 1)), - (11.5, (23, 2)), - ]: + (0.875, (7, 8)), + (-0.875, (-7, 8)), + (0.0, (0, 1)), + (11.5, (23, 2)), + ]: self.assertEqual(f.as_integer_ratio(), ratio) for i in range(10000): @@ -337,7 +339,7 @@ def test_float_containment(self): self.assertTrue((f,) == (f,), "(%r,) != (%r,)" % (f, f)) self.assertTrue({f} == {f}, "{%r} != {%r}" % (f, f)) self.assertTrue({f : None} == {f: None}, "{%r : None} != " - "{%r : None}" % (f, f)) + "{%r : None}" % (f, f)) # identical containers l, t, s, d = [f], (f,), {f}, {f: None} @@ -667,8 +669,6 @@ def test_float_specials_do_unpack(self): ('>> [None for i in range(10)] [None, None, None, None, None, None, None, None, None, None] -########### Tests for various scoping corner cases ############ - -Return lambdas that use the iteration variable as a default argument - - >>> items = [(lambda i=i: i) for i in range(5)] - >>> [x() for x in items] - [0, 1, 2, 3, 4] - -Same again, only this time as a closure variable - - >>> items = [(lambda: i) for i in range(5)] - >>> [x() for x in items] - [4, 4, 4, 4, 4] - -Another way to test that the iteration variable is local to the list comp - - >>> items = [(lambda: i) for i in range(5)] - >>> i = 20 - >>> [x() for x in items] - [4, 4, 4, 4, 4] - -And confirm that a closure can jump over the list comp scope - - >>> items = [(lambda: y) for i in range(5)] - >>> y = 2 - >>> [x() for x in items] - [2, 2, 2, 2, 2] - -We also repeat each of the above scoping tests inside a function - - >>> def test_func(): - ... items = [(lambda i=i: i) for i in range(5)] - ... return [x() for x in items] - >>> test_func() - [0, 1, 2, 3, 4] - - >>> def test_func(): - ... items = [(lambda: i) for i in range(5)] - ... 
return [x() for x in items] - >>> test_func() - [4, 4, 4, 4, 4] - - >>> def test_func(): - ... items = [(lambda: i) for i in range(5)] - ... i = 20 - ... return [x() for x in items] - >>> test_func() - [4, 4, 4, 4, 4] - - >>> def test_func(): - ... items = [(lambda: y) for i in range(5)] - ... y = 2 - ... return [x() for x in items] - >>> test_func() - [2, 2, 2, 2, 2] - """ +class ListComprehensionTest(unittest.TestCase): + def _check_in_scopes(self, code, outputs=None, ns=None, scopes=None, raises=(), + exec_func=exec): + code = textwrap.dedent(code) + scopes = scopes or ["module", "class", "function"] + for scope in scopes: + with self.subTest(scope=scope): + if scope == "class": + newcode = textwrap.dedent(""" + class _C: + {code} + """).format(code=textwrap.indent(code, " ")) + def get_output(moddict, name): + return getattr(moddict["_C"], name) + elif scope == "function": + newcode = textwrap.dedent(""" + def _f(): + {code} + return locals() + _out = _f() + """).format(code=textwrap.indent(code, " ")) + def get_output(moddict, name): + return moddict["_out"][name] + else: + newcode = code + def get_output(moddict, name): + return moddict[name] + newns = ns.copy() if ns else {} + try: + exec_func(newcode, newns) + except raises as e: + # We care about e.g. 
NameError vs UnboundLocalError + self.assertIs(type(e), raises) + else: + for k, v in (outputs or {}).items(): + self.assertEqual(get_output(newns, k), v, k) + + def test_lambdas_with_iteration_var_as_default(self): + code = """ + items = [(lambda i=i: i) for i in range(5)] + y = [x() for x in items] + """ + outputs = {"y": [0, 1, 2, 3, 4]} + self._check_in_scopes(code, outputs) + + def test_lambdas_with_free_var(self): + code = """ + items = [(lambda: i) for i in range(5)] + y = [x() for x in items] + """ + outputs = {"y": [4, 4, 4, 4, 4]} + self._check_in_scopes(code, outputs) + + def test_class_scope_free_var_with_class_cell(self): + class C: + def method(self): + super() + return __class__ + items = [(lambda: i) for i in range(5)] + y = [x() for x in items] + + self.assertEqual(C.y, [4, 4, 4, 4, 4]) + self.assertIs(C().method(), C) + + def test_references_super(self): + code = """ + res = [super for x in [1]] + """ + self._check_in_scopes(code, outputs={"res": [super]}) + + def test_references___class__(self): + code = """ + res = [__class__ for x in [1]] + """ + self._check_in_scopes(code, raises=NameError) + + def test_references___class___defined(self): + code = """ + __class__ = 2 + res = [__class__ for x in [1]] + """ + self._check_in_scopes( + code, outputs={"res": [2]}, scopes=["module", "function"]) + self._check_in_scopes(code, raises=NameError, scopes=["class"]) + + def test_references___class___enclosing(self): + code = """ + __class__ = 2 + class C: + res = [__class__ for x in [1]] + res = C.res + """ + self._check_in_scopes(code, raises=NameError) + + def test_super_and_class_cell_in_sibling_comps(self): + code = """ + [super for _ in [1]] + [__class__ for _ in [1]] + """ + self._check_in_scopes(code, raises=NameError) + + def test_inner_cell_shadows_outer(self): + code = """ + items = [(lambda: i) for i in range(5)] + i = 20 + y = [x() for x in items] + """ + outputs = {"y": [4, 4, 4, 4, 4], "i": 20} + self._check_in_scopes(code, outputs) + + def 
test_inner_cell_shadows_outer_no_store(self): + code = """ + def f(x): + return [lambda: x for x in range(x)], x + fns, x = f(2) + y = [fn() for fn in fns] + """ + outputs = {"y": [1, 1], "x": 2} + self._check_in_scopes(code, outputs) + + def test_closure_can_jump_over_comp_scope(self): + code = """ + items = [(lambda: y) for i in range(5)] + y = 2 + z = [x() for x in items] + """ + outputs = {"z": [2, 2, 2, 2, 2]} + self._check_in_scopes(code, outputs, scopes=["module", "function"]) + + def test_cell_inner_free_outer(self): + code = """ + def f(): + return [lambda: x for x in (x, [1])[1]] + x = ... + y = [fn() for fn in f()] + """ + outputs = {"y": [1]} + self._check_in_scopes(code, outputs, scopes=["module", "function"]) + + def test_free_inner_cell_outer(self): + code = """ + g = 2 + def f(): + return g + y = [g for x in [1]] + """ + outputs = {"y": [2]} + self._check_in_scopes(code, outputs, scopes=["module", "function"]) + self._check_in_scopes(code, scopes=["class"], raises=NameError) + + def test_inner_cell_shadows_outer_redefined(self): + code = """ + y = 10 + items = [(lambda: y) for y in range(5)] + x = y + y = 20 + out = [z() for z in items] + """ + outputs = {"x": 10, "out": [4, 4, 4, 4, 4]} + self._check_in_scopes(code, outputs) + + def test_shadows_outer_cell(self): + code = """ + def inner(): + return g + [g for g in range(5)] + x = inner() + """ + outputs = {"x": -1} + self._check_in_scopes(code, outputs, ns={"g": -1}) + + def test_explicit_global(self): + code = """ + global g + x = g + g = 2 + items = [g for g in [1]] + y = g + """ + outputs = {"x": 1, "y": 2, "items": [1]} + self._check_in_scopes(code, outputs, ns={"g": 1}) + + def test_explicit_global_2(self): + code = """ + global g + x = g + g = 2 + items = [g for x in [1]] + y = g + """ + outputs = {"x": 1, "y": 2, "items": [2]} + self._check_in_scopes(code, outputs, ns={"g": 1}) + + def test_explicit_global_3(self): + code = """ + global g + fns = [lambda: g for g in [2]] + items = [fn() for 
fn in fns] + """ + outputs = {"items": [2]} + self._check_in_scopes(code, outputs, ns={"g": 1}) + + def test_assignment_expression(self): + code = """ + x = -1 + items = [(x:=y) for y in range(3)] + """ + outputs = {"x": 2} + # assignment expression in comprehension is disallowed in class scope + self._check_in_scopes(code, outputs, scopes=["module", "function"]) + + def test_free_var_in_comp_child(self): + code = """ + lst = range(3) + funcs = [lambda: x for x in lst] + inc = [x + 1 for x in lst] + [x for x in inc] + x = funcs[0]() + """ + outputs = {"x": 2} + self._check_in_scopes(code, outputs) + + def test_shadow_with_free_and_local(self): + code = """ + lst = range(3) + x = -1 + funcs = [lambda: x for x in lst] + items = [x + 1 for x in lst] + """ + outputs = {"x": -1} + self._check_in_scopes(code, outputs) + + def test_shadow_comp_iterable_name(self): + code = """ + x = [1] + y = [x for x in x] + """ + outputs = {"x": [1]} + self._check_in_scopes(code, outputs) + + def test_nested_free(self): + code = """ + x = 1 + def g(): + [x for x in range(3)] + return x + g() + """ + outputs = {"x": 1} + self._check_in_scopes(code, outputs, scopes=["module", "function"]) + + def test_introspecting_frame_locals(self): + code = """ + import sys + [i for i in range(2)] + i = 20 + sys._getframe().f_locals + """ + outputs = {"i": 20} + self._check_in_scopes(code, outputs) + + def test_nested(self): + code = """ + l = [2, 3] + y = [[x ** 2 for x in range(x)] for x in l] + """ + outputs = {"y": [[0, 1], [0, 1, 4]]} + self._check_in_scopes(code, outputs) + + def test_nested_2(self): + code = """ + l = [1, 2, 3] + x = 3 + y = [x for [x ** x for x in range(x)][x - 1] in l] + """ + outputs = {"y": [3, 3, 3]} + self._check_in_scopes(code, outputs, scopes=["module", "function"]) + self._check_in_scopes(code, scopes=["class"], raises=NameError) + + def test_nested_3(self): + code = """ + l = [(1, 2), (3, 4), (5, 6)] + y = [x for (x, [x ** x for x in range(x)][x - 1]) in l] + """ + 
outputs = {"y": [1, 3, 5]} + self._check_in_scopes(code, outputs) + + def test_nested_4(self): + code = """ + items = [([lambda: x for x in range(2)], lambda: x) for x in range(3)] + out = [([fn() for fn in fns], fn()) for fns, fn in items] + """ + outputs = {"out": [([1, 1], 2), ([1, 1], 2), ([1, 1], 2)]} + self._check_in_scopes(code, outputs) + + def test_nameerror(self): + code = """ + [x for x in [1]] + x + """ + + self._check_in_scopes(code, raises=NameError) + + def test_dunder_name(self): + code = """ + y = [__x for __x in [1]] + """ + outputs = {"y": [1]} + self._check_in_scopes(code, outputs) + + def test_unbound_local_after_comprehension(self): + def f(): + if False: + x = 0 + [x for x in [1]] + return x + + with self.assertRaises(UnboundLocalError): + f() + + def test_unbound_local_inside_comprehension(self): + def f(): + l = [None] + return [1 for (l[0], l) in [[1, 2]]] + + with self.assertRaises(UnboundLocalError): + f() + + def test_global_outside_cellvar_inside_plus_freevar(self): + code = """ + a = 1 + def f(): + func, = [(lambda: b) for b in [a]] + return b, func() + x = f() + """ + self._check_in_scopes( + code, {"x": (2, 1)}, ns={"b": 2}, scopes=["function", "module"]) + # inside a class, the `a = 1` assignment is not visible + self._check_in_scopes(code, raises=NameError, scopes=["class"]) + + def test_cell_in_nested_comprehension(self): + code = """ + a = 1 + def f(): + (func, inner_b), = [[lambda: b for b in c] + [b] for c in [[a]]] + return b, inner_b, func() + x = f() + """ + self._check_in_scopes( + code, {"x": (2, 2, 1)}, ns={"b": 2}, scopes=["function", "module"]) + # inside a class, the `a = 1` assignment is not visible + self._check_in_scopes(code, raises=NameError, scopes=["class"]) + + def test_name_error_in_class_scope(self): + code = """ + y = 1 + [x + y for x in range(2)] + """ + self._check_in_scopes(code, raises=NameError, scopes=["class"]) + + def test_global_in_class_scope(self): + code = """ + y = 2 + vals = [(x, y) for x in 
range(2)] + """ + outputs = {"vals": [(0, 1), (1, 1)]} + self._check_in_scopes(code, outputs, ns={"y": 1}, scopes=["class"]) + + def test_in_class_scope_inside_function_1(self): + code = """ + class C: + y = 2 + vals = [(x, y) for x in range(2)] + vals = C.vals + """ + outputs = {"vals": [(0, 1), (1, 1)]} + self._check_in_scopes(code, outputs, ns={"y": 1}, scopes=["function"]) + + def test_in_class_scope_inside_function_2(self): + code = """ + y = 1 + class C: + y = 2 + vals = [(x, y) for x in range(2)] + vals = C.vals + """ + outputs = {"vals": [(0, 1), (1, 1)]} + self._check_in_scopes(code, outputs, scopes=["function"]) + + def test_in_class_scope_with_global(self): + code = """ + y = 1 + class C: + global y + y = 2 + # Ensure the listcomp uses the global, not the value in the + # class namespace + locals()['y'] = 3 + vals = [(x, y) for x in range(2)] + vals = C.vals + """ + outputs = {"vals": [(0, 2), (1, 2)]} + self._check_in_scopes(code, outputs, scopes=["module", "class"]) + outputs = {"vals": [(0, 1), (1, 1)]} + self._check_in_scopes(code, outputs, scopes=["function"]) + + def test_in_class_scope_with_nonlocal(self): + code = """ + y = 1 + class C: + nonlocal y + y = 2 + # Ensure the listcomp uses the global, not the value in the + # class namespace + locals()['y'] = 3 + vals = [(x, y) for x in range(2)] + vals = C.vals + """ + outputs = {"vals": [(0, 2), (1, 2)]} + self._check_in_scopes(code, outputs, scopes=["function"]) + + def test_nested_has_free_var(self): + code = """ + items = [a for a in [1] if [a for _ in [0]]] + """ + outputs = {"items": [1]} + self._check_in_scopes(code, outputs, scopes=["class"]) + + def test_nested_free_var_not_bound_in_outer_comp(self): + code = """ + z = 1 + items = [a for a in [1] if [x for x in [1] if z]] + """ + self._check_in_scopes(code, {"items": [1]}, scopes=["module", "function"]) + self._check_in_scopes(code, {"items": []}, ns={"z": 0}, scopes=["class"]) + + def test_nested_free_var_in_iter(self): + code = """ + 
items = [_C for _C in [1] for [0, 1][[x for x in [1] if _C][0]] in [2]] + """ + self._check_in_scopes(code, {"items": [1]}) + + def test_nested_free_var_in_expr(self): + code = """ + items = [(_C, [x for x in [1] if _C]) for _C in [0, 1]] + """ + self._check_in_scopes(code, {"items": [(0, []), (1, [1])]}) + + def test_nested_listcomp_in_lambda(self): + code = """ + f = [(z, lambda y: [(x, y, z) for x in [3]]) for z in [1]] + (z, func), = f + out = func(2) + """ + self._check_in_scopes(code, {"z": 1, "out": [(3, 2, 1)]}) + + def test_lambda_in_iter(self): + code = """ + (func, c), = [(a, b) for b in [1] for a in [lambda : a]] + d = func() + assert d is func + # must use "a" in this scope + e = a if False else None + """ + self._check_in_scopes(code, {"c": 1, "e": None}) + + def test_assign_to_comp_iter_var_in_outer_function(self): + code = """ + a = [1 for a in [0]] + """ + self._check_in_scopes(code, {"a": [1]}, scopes=["function"]) + + def test_no_leakage_to_locals(self): + code = """ + def b(): + [a for b in [1] for _ in []] + return b, locals() + r, s = b() + x = r is b + y = list(s.keys()) + """ + self._check_in_scopes(code, {"x": True, "y": []}, scopes=["module"]) + self._check_in_scopes(code, {"x": True, "y": ["b"]}, scopes=["function"]) + self._check_in_scopes(code, raises=NameError, scopes=["class"]) + + def test_iter_var_available_in_locals(self): + code = """ + l = [1, 2] + y = 0 + items = [locals()["x"] for x in l] + items2 = [vars()["x"] for x in l] + items3 = [("x" in dir()) for x in l] + items4 = [eval("x") for x in l] + # x is available, and does not overwrite y + [exec("y = x") for x in l] + """ + self._check_in_scopes( + code, + { + "items": [1, 2], + "items2": [1, 2], + "items3": [True, True], + "items4": [1, 2], + "y": 0 + } + ) + + def test_comp_in_try_except(self): + template = """ + value = ["ab"] + result = snapshot = None + try: + result = [{func}(value) for value in value] + except: + snapshot = value + raise + """ + # No exception. 
+ code = template.format(func='len') + self._check_in_scopes(code, {"value": ["ab"], "result": [2], "snapshot": None}) + # Handles exception. + code = template.format(func='int') + self._check_in_scopes(code, {"value": ["ab"], "result": None, "snapshot": ["ab"]}, + raises=ValueError) + + def test_comp_in_try_finally(self): + template = """ + value = ["ab"] + result = snapshot = None + try: + result = [{func}(value) for value in value] + finally: + snapshot = value + """ + # No exception. + code = template.format(func='len') + self._check_in_scopes(code, {"value": ["ab"], "result": [2], "snapshot": ["ab"]}) + # Handles exception. + code = template.format(func='int') + self._check_in_scopes(code, {"value": ["ab"], "result": None, "snapshot": ["ab"]}, + raises=ValueError) + + def test_exception_in_post_comp_call(self): + code = """ + value = [1, None] + try: + [v for v in value].sort() + except: + pass + """ + self._check_in_scopes(code, {"value": [1, None]}) + + def test_frame_locals(self): + code = """ + val = [sys._getframe().f_locals for a in [0]][0]["a"] + """ + import sys + self._check_in_scopes(code, {"val": 0}, ns={"sys": sys}) + + def _recursive_replace(self, maybe_code): + if not isinstance(maybe_code, types.CodeType): + return maybe_code + return maybe_code.replace(co_consts=tuple( + self._recursive_replace(c) for c in maybe_code.co_consts + )) + + def _replacing_exec(self, code_string, ns): + co = compile(code_string, "", "exec") + co = self._recursive_replace(co) + exec(co, ns) + + def test_code_replace(self): + code = """ + x = 3 + [x for x in (1, 2)] + dir() + y = [x] + """ + self._check_in_scopes(code, {"y": [3], "x": 3}) + self._check_in_scopes(code, {"y": [3], "x": 3}, exec_func=self._replacing_exec) + + def test_code_replace_extended_arg(self): + num_names = 300 + assignments = "; ".join(f"x{i} = {i}" for i in range(num_names)) + name_list = ", ".join(f"x{i}" for i in range(num_names)) + expected = { + "y": list(range(num_names)), + **{f"x{i}": i 
for i in range(num_names)} + } + code = f""" + {assignments} + [({name_list}) for {name_list} in (range(300),)] + dir() + y = [{name_list}] + """ + self._check_in_scopes(code, expected) + self._check_in_scopes(code, expected, exec_func=self._replacing_exec) + + def test_multiple_comprehension_name_reuse(self): + code = """ + [x for x in [1]] + y = [x for _ in [1]] + """ + self._check_in_scopes(code, {"y": [3]}, ns={"x": 3}) + + code = """ + x = 2 + [x for x in [1]] + y = [x for _ in [1]] + """ + self._check_in_scopes(code, {"x": 2, "y": [3]}, ns={"x": 3}, scopes=["class"]) + self._check_in_scopes(code, {"x": 2, "y": [2]}, ns={"x": 3}, scopes=["function", "module"]) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_exception_locations(self): + # The location of an exception raised from __init__ or + # __next__ should should be the iterator expression + + def init_raises(): + try: + [x for x in BrokenIter(init_raises=True)] + except Exception as e: + return e + + def next_raises(): + try: + [x for x in BrokenIter(next_raises=True)] + except Exception as e: + return e + + def iter_raises(): + try: + [x for x in BrokenIter(iter_raises=True)] + except Exception as e: + return e + + for func, expected in [(init_raises, "BrokenIter(init_raises=True)"), + (next_raises, "BrokenIter(next_raises=True)"), + (iter_raises, "BrokenIter(iter_raises=True)"), + ]: + with self.subTest(func): + exc = func() + f = traceback.extract_tb(exc.__traceback__)[0] + indent = 16 + co = func.__code__ + self.assertEqual(f.lineno, co.co_firstlineno + 2) + self.assertEqual(f.end_lineno, co.co_firstlineno + 2) + self.assertEqual(f.line[f.colno - indent : f.end_colno - indent], + expected) + __test__ = {'doctests' : doctests} def load_tests(loader, tests, pattern): diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py new file mode 100644 index 0000000000..a570d65f6c --- /dev/null +++ b/Lib/test/test_logging.py @@ -0,0 +1,7049 @@ +# Copyright 2001-2022 by Vinay Sajip. 
All Rights Reserved. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose and without fee is hereby granted, +# provided that the above copyright notice appear in all copies and that +# both that copyright notice and this permission notice appear in +# supporting documentation, and that the name of Vinay Sajip +# not be used in advertising or publicity pertaining to distribution +# of the software without specific, written prior permission. +# VINAY SAJIP DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING +# ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL +# VINAY SAJIP BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR +# ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER +# IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""Test harness for the logging module. Run all tests. + +Copyright (C) 2001-2022 Vinay Sajip. All Rights Reserved. 
+""" +import logging +import logging.handlers +import logging.config + +import codecs +import configparser +import copy +import datetime +import pathlib +import pickle +import io +import itertools +import gc +import json +import os +import queue +import random +import re +import shutil +import socket +import struct +import sys +import tempfile +from test.support.script_helper import assert_python_ok, assert_python_failure +from test import support +from test.support import import_helper +from test.support import os_helper +from test.support import socket_helper +from test.support import threading_helper +from test.support import warnings_helper +from test.support import asyncore +from test.support import smtpd +from test.support.logging_helper import TestHandler +import textwrap +import threading +import asyncio +import time +import unittest +import warnings +import weakref + +from http.server import HTTPServer, BaseHTTPRequestHandler +from unittest.mock import patch +from urllib.parse import urlparse, parse_qs +from socketserver import (ThreadingUDPServer, DatagramRequestHandler, + ThreadingTCPServer, StreamRequestHandler) + +try: + import win32evtlog, win32evtlogutil, pywintypes +except ImportError: + win32evtlog = win32evtlogutil = pywintypes = None + +try: + import zlib +except ImportError: + pass + + +# gh-89363: Skip fork() test if Python is built with Address Sanitizer (ASAN) +# to work around a libasan race condition, dead lock in pthread_create(). 
+skip_if_asan_fork = unittest.skipIf( + support.HAVE_ASAN_FORK_BUG, + "libasan has a pthread_create() dead lock related to thread+fork") +skip_if_tsan_fork = unittest.skipIf( + support.check_sanitizer(thread=True), + "TSAN doesn't support threads after fork") + + +class BaseTest(unittest.TestCase): + + """Base class for logging tests.""" + + log_format = "%(name)s -> %(levelname)s: %(message)s" + expected_log_pat = r"^([\w.]+) -> (\w+): (\d+)$" + message_num = 0 + + def setUp(self): + """Setup the default logging stream to an internal StringIO instance, + so that we can examine log output as we want.""" + self._threading_key = threading_helper.threading_setup() + + logger_dict = logging.getLogger().manager.loggerDict + logging._acquireLock() + try: + self.saved_handlers = logging._handlers.copy() + self.saved_handler_list = logging._handlerList[:] + self.saved_loggers = saved_loggers = logger_dict.copy() + self.saved_name_to_level = logging._nameToLevel.copy() + self.saved_level_to_name = logging._levelToName.copy() + self.logger_states = logger_states = {} + for name in saved_loggers: + logger_states[name] = getattr(saved_loggers[name], + 'disabled', None) + finally: + logging._releaseLock() + + # Set two unused loggers + self.logger1 = logging.getLogger("\xab\xd7\xbb") + self.logger2 = logging.getLogger("\u013f\u00d6\u0047") + + self.root_logger = logging.getLogger("") + self.original_logging_level = self.root_logger.getEffectiveLevel() + + self.stream = io.StringIO() + self.root_logger.setLevel(logging.DEBUG) + self.root_hdlr = logging.StreamHandler(self.stream) + self.root_formatter = logging.Formatter(self.log_format) + self.root_hdlr.setFormatter(self.root_formatter) + if self.logger1.hasHandlers(): + hlist = self.logger1.handlers + self.root_logger.handlers + raise AssertionError('Unexpected handlers: %s' % hlist) + if self.logger2.hasHandlers(): + hlist = self.logger2.handlers + self.root_logger.handlers + raise AssertionError('Unexpected handlers: %s' % 
hlist) + self.root_logger.addHandler(self.root_hdlr) + self.assertTrue(self.logger1.hasHandlers()) + self.assertTrue(self.logger2.hasHandlers()) + + def tearDown(self): + """Remove our logging stream, and restore the original logging + level.""" + self.stream.close() + self.root_logger.removeHandler(self.root_hdlr) + while self.root_logger.handlers: + h = self.root_logger.handlers[0] + self.root_logger.removeHandler(h) + h.close() + self.root_logger.setLevel(self.original_logging_level) + logging._acquireLock() + try: + logging._levelToName.clear() + logging._levelToName.update(self.saved_level_to_name) + logging._nameToLevel.clear() + logging._nameToLevel.update(self.saved_name_to_level) + logging._handlers.clear() + logging._handlers.update(self.saved_handlers) + logging._handlerList[:] = self.saved_handler_list + manager = logging.getLogger().manager + manager.disable = 0 + loggerDict = manager.loggerDict + loggerDict.clear() + loggerDict.update(self.saved_loggers) + logger_states = self.logger_states + for name in self.logger_states: + if logger_states[name] is not None: + self.saved_loggers[name].disabled = logger_states[name] + finally: + logging._releaseLock() + + self.doCleanups() + threading_helper.threading_cleanup(*self._threading_key) + + def assert_log_lines(self, expected_values, stream=None, pat=None): + """Match the collected log lines against the regular expression + self.expected_log_pat, and compare the extracted group values to + the expected_values list of tuples.""" + stream = stream or self.stream + pat = re.compile(pat or self.expected_log_pat) + actual_lines = stream.getvalue().splitlines() + self.assertEqual(len(actual_lines), len(expected_values)) + for actual, expected in zip(actual_lines, expected_values): + match = pat.search(actual) + if not match: + self.fail("Log line does not match expected pattern:\n" + + actual) + self.assertEqual(tuple(match.groups()), expected) + s = stream.read() + if s: + self.fail("Remaining output at end of 
log stream:\n" + s) + + def next_message(self): + """Generate a message consisting solely of an auto-incrementing + integer.""" + self.message_num += 1 + return "%d" % self.message_num + + +class BuiltinLevelsTest(BaseTest): + """Test builtin levels and their inheritance.""" + + def test_flat(self): + # Logging levels in a flat logger namespace. + m = self.next_message + + ERR = logging.getLogger("ERR") + ERR.setLevel(logging.ERROR) + INF = logging.LoggerAdapter(logging.getLogger("INF"), {}) + INF.setLevel(logging.INFO) + DEB = logging.getLogger("DEB") + DEB.setLevel(logging.DEBUG) + + # These should log. + ERR.log(logging.CRITICAL, m()) + ERR.error(m()) + + INF.log(logging.CRITICAL, m()) + INF.error(m()) + INF.warning(m()) + INF.info(m()) + + DEB.log(logging.CRITICAL, m()) + DEB.error(m()) + DEB.warning(m()) + DEB.info(m()) + DEB.debug(m()) + + # These should not log. + ERR.warning(m()) + ERR.info(m()) + ERR.debug(m()) + + INF.debug(m()) + + self.assert_log_lines([ + ('ERR', 'CRITICAL', '1'), + ('ERR', 'ERROR', '2'), + ('INF', 'CRITICAL', '3'), + ('INF', 'ERROR', '4'), + ('INF', 'WARNING', '5'), + ('INF', 'INFO', '6'), + ('DEB', 'CRITICAL', '7'), + ('DEB', 'ERROR', '8'), + ('DEB', 'WARNING', '9'), + ('DEB', 'INFO', '10'), + ('DEB', 'DEBUG', '11'), + ]) + + def test_nested_explicit(self): + # Logging levels in a nested namespace, all explicitly set. + m = self.next_message + + INF = logging.getLogger("INF") + INF.setLevel(logging.INFO) + INF_ERR = logging.getLogger("INF.ERR") + INF_ERR.setLevel(logging.ERROR) + + # These should log. + INF_ERR.log(logging.CRITICAL, m()) + INF_ERR.error(m()) + + # These should not log. + INF_ERR.warning(m()) + INF_ERR.info(m()) + INF_ERR.debug(m()) + + self.assert_log_lines([ + ('INF.ERR', 'CRITICAL', '1'), + ('INF.ERR', 'ERROR', '2'), + ]) + + def test_nested_inherited(self): + # Logging levels in a nested namespace, inherited from parent loggers. 
+ m = self.next_message + + INF = logging.getLogger("INF") + INF.setLevel(logging.INFO) + INF_ERR = logging.getLogger("INF.ERR") + INF_ERR.setLevel(logging.ERROR) + INF_UNDEF = logging.getLogger("INF.UNDEF") + INF_ERR_UNDEF = logging.getLogger("INF.ERR.UNDEF") + UNDEF = logging.getLogger("UNDEF") + + # These should log. + INF_UNDEF.log(logging.CRITICAL, m()) + INF_UNDEF.error(m()) + INF_UNDEF.warning(m()) + INF_UNDEF.info(m()) + INF_ERR_UNDEF.log(logging.CRITICAL, m()) + INF_ERR_UNDEF.error(m()) + + # These should not log. + INF_UNDEF.debug(m()) + INF_ERR_UNDEF.warning(m()) + INF_ERR_UNDEF.info(m()) + INF_ERR_UNDEF.debug(m()) + + self.assert_log_lines([ + ('INF.UNDEF', 'CRITICAL', '1'), + ('INF.UNDEF', 'ERROR', '2'), + ('INF.UNDEF', 'WARNING', '3'), + ('INF.UNDEF', 'INFO', '4'), + ('INF.ERR.UNDEF', 'CRITICAL', '5'), + ('INF.ERR.UNDEF', 'ERROR', '6'), + ]) + + def test_nested_with_virtual_parent(self): + # Logging levels when some parent does not exist yet. + m = self.next_message + + INF = logging.getLogger("INF") + GRANDCHILD = logging.getLogger("INF.BADPARENT.UNDEF") + CHILD = logging.getLogger("INF.BADPARENT") + INF.setLevel(logging.INFO) + + # These should log. + GRANDCHILD.log(logging.FATAL, m()) + GRANDCHILD.info(m()) + CHILD.log(logging.FATAL, m()) + CHILD.info(m()) + + # These should not log. 
+ GRANDCHILD.debug(m()) + CHILD.debug(m()) + + self.assert_log_lines([ + ('INF.BADPARENT.UNDEF', 'CRITICAL', '1'), + ('INF.BADPARENT.UNDEF', 'INFO', '2'), + ('INF.BADPARENT', 'CRITICAL', '3'), + ('INF.BADPARENT', 'INFO', '4'), + ]) + + def test_regression_22386(self): + """See issue #22386 for more information.""" + self.assertEqual(logging.getLevelName('INFO'), logging.INFO) + self.assertEqual(logging.getLevelName(logging.INFO), 'INFO') + + def test_issue27935(self): + fatal = logging.getLevelName('FATAL') + self.assertEqual(fatal, logging.FATAL) + + def test_regression_29220(self): + """See issue #29220 for more information.""" + logging.addLevelName(logging.INFO, '') + self.addCleanup(logging.addLevelName, logging.INFO, 'INFO') + self.assertEqual(logging.getLevelName(logging.INFO), '') + self.assertEqual(logging.getLevelName(logging.NOTSET), 'NOTSET') + self.assertEqual(logging.getLevelName('NOTSET'), logging.NOTSET) + +class BasicFilterTest(BaseTest): + + """Test the bundled Filter class.""" + + def test_filter(self): + # Only messages satisfying the specified criteria pass through the + # filter. + filter_ = logging.Filter("spam.eggs") + handler = self.root_logger.handlers[0] + try: + handler.addFilter(filter_) + spam = logging.getLogger("spam") + spam_eggs = logging.getLogger("spam.eggs") + spam_eggs_fish = logging.getLogger("spam.eggs.fish") + spam_bakedbeans = logging.getLogger("spam.bakedbeans") + + spam.info(self.next_message()) + spam_eggs.info(self.next_message()) # Good. + spam_eggs_fish.info(self.next_message()) # Good. + spam_bakedbeans.info(self.next_message()) + + self.assert_log_lines([ + ('spam.eggs', 'INFO', '2'), + ('spam.eggs.fish', 'INFO', '3'), + ]) + finally: + handler.removeFilter(filter_) + + def test_callable_filter(self): + # Only messages satisfying the specified criteria pass through the + # filter. 
+ + def filterfunc(record): + parts = record.name.split('.') + prefix = '.'.join(parts[:2]) + return prefix == 'spam.eggs' + + handler = self.root_logger.handlers[0] + try: + handler.addFilter(filterfunc) + spam = logging.getLogger("spam") + spam_eggs = logging.getLogger("spam.eggs") + spam_eggs_fish = logging.getLogger("spam.eggs.fish") + spam_bakedbeans = logging.getLogger("spam.bakedbeans") + + spam.info(self.next_message()) + spam_eggs.info(self.next_message()) # Good. + spam_eggs_fish.info(self.next_message()) # Good. + spam_bakedbeans.info(self.next_message()) + + self.assert_log_lines([ + ('spam.eggs', 'INFO', '2'), + ('spam.eggs.fish', 'INFO', '3'), + ]) + finally: + handler.removeFilter(filterfunc) + + def test_empty_filter(self): + f = logging.Filter() + r = logging.makeLogRecord({'name': 'spam.eggs'}) + self.assertTrue(f.filter(r)) + +# +# First, we define our levels. There can be as many as you want - the only +# limitations are that they should be integers, the lowest should be > 0 and +# larger values mean less information being logged. If you need specific +# level values which do not fit into these limitations, you can use a +# mapping dictionary to convert between your application levels and the +# logging system. +# +SILENT = 120 +TACITURN = 119 +TERSE = 118 +EFFUSIVE = 117 +SOCIABLE = 116 +VERBOSE = 115 +TALKATIVE = 114 +GARRULOUS = 113 +CHATTERBOX = 112 +BORING = 111 + +LEVEL_RANGE = range(BORING, SILENT + 1) + +# +# Next, we define names for our levels. You don't need to do this - in which +# case the system will use "Level n" to denote the text for the level. 
+#
+my_logging_levels = {
+    SILENT      : 'Silent',
+    TACITURN    : 'Taciturn',
+    TERSE       : 'Terse',
+    EFFUSIVE    : 'Effusive',
+    SOCIABLE    : 'Sociable',
+    VERBOSE     : 'Verbose',
+    TALKATIVE   : 'Talkative',
+    GARRULOUS   : 'Garrulous',
+    CHATTERBOX  : 'Chatterbox',
+    BORING      : 'Boring',
+}
+
+class GarrulousFilter(logging.Filter):
+
+    """A filter which blocks garrulous messages."""
+
+    def filter(self, record):
+        return record.levelno != GARRULOUS
+
+class VerySpecificFilter(logging.Filter):
+
+    """A filter which blocks sociable and taciturn messages."""
+
+    def filter(self, record):
+        return record.levelno not in [SOCIABLE, TACITURN]
+
+
+class CustomLevelsAndFiltersTest(BaseTest):
+
+    """Test various filtering possibilities with custom logging levels."""
+
+    # Skip the logger name group.
+    expected_log_pat = r"^[\w.]+ -> (\w+): (\d+)$"
+
+    def setUp(self):
+        BaseTest.setUp(self)
+        for k, v in my_logging_levels.items():
+            logging.addLevelName(k, v)
+
+    def log_at_all_levels(self, logger):
+        for lvl in LEVEL_RANGE:
+            logger.log(lvl, self.next_message())
+
+    def test_handler_filter_replaces_record(self):
+        def replace_message(record: logging.LogRecord):
+            record = copy.copy(record)
+            record.msg = "new message!"
+            return record
+
+        # Set up a logging hierarchy such that "child" and its handler
+        # (and thus `replace_message()`) always get called before
+        # propagating up to "parent".
+        # Then we can confirm that `replace_message()` was able to
+        # replace the log record without having a side effect on
+        # other loggers or handlers.
+ parent = logging.getLogger("parent") + child = logging.getLogger("parent.child") + stream_1 = io.StringIO() + stream_2 = io.StringIO() + handler_1 = logging.StreamHandler(stream_1) + handler_2 = logging.StreamHandler(stream_2) + handler_2.addFilter(replace_message) + parent.addHandler(handler_1) + child.addHandler(handler_2) + + child.info("original message") + handler_1.flush() + handler_2.flush() + self.assertEqual(stream_1.getvalue(), "original message\n") + self.assertEqual(stream_2.getvalue(), "new message!\n") + + def test_logging_filter_replaces_record(self): + records = set() + + class RecordingFilter(logging.Filter): + def filter(self, record: logging.LogRecord): + records.add(id(record)) + return copy.copy(record) + + logger = logging.getLogger("logger") + logger.setLevel(logging.INFO) + logger.addFilter(RecordingFilter()) + logger.addFilter(RecordingFilter()) + + logger.info("msg") + + self.assertEqual(2, len(records)) + + def test_logger_filter(self): + # Filter at logger level. + self.root_logger.setLevel(VERBOSE) + # Levels >= 'Verbose' are good. + self.log_at_all_levels(self.root_logger) + self.assert_log_lines([ + ('Verbose', '5'), + ('Sociable', '6'), + ('Effusive', '7'), + ('Terse', '8'), + ('Taciturn', '9'), + ('Silent', '10'), + ]) + + def test_handler_filter(self): + # Filter at handler level. + self.root_logger.handlers[0].setLevel(SOCIABLE) + try: + # Levels >= 'Sociable' are good. + self.log_at_all_levels(self.root_logger) + self.assert_log_lines([ + ('Sociable', '6'), + ('Effusive', '7'), + ('Terse', '8'), + ('Taciturn', '9'), + ('Silent', '10'), + ]) + finally: + self.root_logger.handlers[0].setLevel(logging.NOTSET) + + def test_specific_filters(self): + # Set a specific filter object on the handler, and then add another + # filter object on the logger itself. 
+ handler = self.root_logger.handlers[0] + specific_filter = None + garr = GarrulousFilter() + handler.addFilter(garr) + try: + self.log_at_all_levels(self.root_logger) + first_lines = [ + # Notice how 'Garrulous' is missing + ('Boring', '1'), + ('Chatterbox', '2'), + ('Talkative', '4'), + ('Verbose', '5'), + ('Sociable', '6'), + ('Effusive', '7'), + ('Terse', '8'), + ('Taciturn', '9'), + ('Silent', '10'), + ] + self.assert_log_lines(first_lines) + + specific_filter = VerySpecificFilter() + self.root_logger.addFilter(specific_filter) + self.log_at_all_levels(self.root_logger) + self.assert_log_lines(first_lines + [ + # Not only 'Garrulous' is still missing, but also 'Sociable' + # and 'Taciturn' + ('Boring', '11'), + ('Chatterbox', '12'), + ('Talkative', '14'), + ('Verbose', '15'), + ('Effusive', '17'), + ('Terse', '18'), + ('Silent', '20'), + ]) + finally: + if specific_filter: + self.root_logger.removeFilter(specific_filter) + handler.removeFilter(garr) + + +def make_temp_file(*args, **kwargs): + fd, fn = tempfile.mkstemp(*args, **kwargs) + os.close(fd) + return fn + + +class HandlerTest(BaseTest): + def test_name(self): + h = logging.Handler() + h.name = 'generic' + self.assertEqual(h.name, 'generic') + h.name = 'anothergeneric' + self.assertEqual(h.name, 'anothergeneric') + self.assertRaises(NotImplementedError, h.emit, None) + + def test_builtin_handlers(self): + # We can't actually *use* too many handlers in the tests, + # but we can try instantiating them with various options + if sys.platform in ('linux', 'darwin'): + for existing in (True, False): + fn = make_temp_file() + if not existing: + os.unlink(fn) + h = logging.handlers.WatchedFileHandler(fn, encoding='utf-8', delay=True) + if existing: + dev, ino = h.dev, h.ino + self.assertEqual(dev, -1) + self.assertEqual(ino, -1) + r = logging.makeLogRecord({'msg': 'Test'}) + h.handle(r) + # Now remove the file. + os.unlink(fn) + self.assertFalse(os.path.exists(fn)) + # The next call should recreate the file. 
+ h.handle(r) + self.assertTrue(os.path.exists(fn)) + else: + self.assertEqual(h.dev, -1) + self.assertEqual(h.ino, -1) + h.close() + if existing: + os.unlink(fn) + if sys.platform == 'darwin': + sockname = '/var/run/syslog' + else: + sockname = '/dev/log' + try: + h = logging.handlers.SysLogHandler(sockname) + self.assertEqual(h.facility, h.LOG_USER) + self.assertTrue(h.unixsocket) + h.close() + except OSError: # syslogd might not be available + pass + for method in ('GET', 'POST', 'PUT'): + if method == 'PUT': + self.assertRaises(ValueError, logging.handlers.HTTPHandler, + 'localhost', '/log', method) + else: + h = logging.handlers.HTTPHandler('localhost', '/log', method) + h.close() + h = logging.handlers.BufferingHandler(0) + r = logging.makeLogRecord({}) + self.assertTrue(h.shouldFlush(r)) + h.close() + h = logging.handlers.BufferingHandler(1) + self.assertFalse(h.shouldFlush(r)) + h.close() + + def test_pathlike_objects(self): + """ + Test that path-like objects are accepted as filename arguments to handlers. + + See Issue #27493. + """ + fn = make_temp_file() + os.unlink(fn) + pfn = os_helper.FakePath(fn) + cases = ( + (logging.FileHandler, (pfn, 'w')), + (logging.handlers.RotatingFileHandler, (pfn, 'a')), + (logging.handlers.TimedRotatingFileHandler, (pfn, 'h')), + ) + if sys.platform in ('linux', 'darwin'): + cases += ((logging.handlers.WatchedFileHandler, (pfn, 'w')),) + for cls, args in cases: + h = cls(*args, encoding="utf-8") + self.assertTrue(os.path.exists(fn)) + h.close() + os.unlink(fn) + + @unittest.skipIf(os.name == 'nt', 'WatchedFileHandler not appropriate for Windows.') + @unittest.skipIf( + support.is_emscripten, "Emscripten cannot fstat unlinked files." + ) + @threading_helper.requires_working_threading() + @support.requires_resource('walltime') + def test_race(self): + # Issue #14632 refers. 
+ def remove_loop(fname, tries): + for _ in range(tries): + try: + os.unlink(fname) + self.deletion_time = time.time() + except OSError: + pass + time.sleep(0.004 * random.randint(0, 4)) + + del_count = 500 + log_count = 500 + + self.handle_time = None + self.deletion_time = None + + for delay in (False, True): + fn = make_temp_file('.log', 'test_logging-3-') + remover = threading.Thread(target=remove_loop, args=(fn, del_count)) + remover.daemon = True + remover.start() + h = logging.handlers.WatchedFileHandler(fn, encoding='utf-8', delay=delay) + f = logging.Formatter('%(asctime)s: %(levelname)s: %(message)s') + h.setFormatter(f) + try: + for _ in range(log_count): + time.sleep(0.005) + r = logging.makeLogRecord({'msg': 'testing' }) + try: + self.handle_time = time.time() + h.handle(r) + except Exception: + print('Deleted at %s, ' + 'opened at %s' % (self.deletion_time, + self.handle_time)) + raise + finally: + remover.join() + h.close() + if os.path.exists(fn): + os.unlink(fn) + + # The implementation relies on os.register_at_fork existing, but we test + # based on os.fork existing because that is what users and this test use. + # This helps ensure that when fork exists (the important concept) that the + # register_at_fork mechanism is also present and used. 
+ @support.requires_fork() + @threading_helper.requires_working_threading() + @skip_if_asan_fork + @skip_if_tsan_fork + def test_post_fork_child_no_deadlock(self): + """Ensure child logging locks are not held; bpo-6721 & bpo-36533.""" + class _OurHandler(logging.Handler): + def __init__(self): + super().__init__() + self.sub_handler = logging.StreamHandler( + stream=open('/dev/null', 'wt', encoding='utf-8')) + + def emit(self, record): + self.sub_handler.acquire() + try: + self.sub_handler.emit(record) + finally: + self.sub_handler.release() + + self.assertEqual(len(logging._handlers), 0) + refed_h = _OurHandler() + self.addCleanup(refed_h.sub_handler.stream.close) + refed_h.name = 'because we need at least one for this test' + self.assertGreater(len(logging._handlers), 0) + self.assertGreater(len(logging._at_fork_reinit_lock_weakset), 1) + test_logger = logging.getLogger('test_post_fork_child_no_deadlock') + test_logger.addHandler(refed_h) + test_logger.setLevel(logging.DEBUG) + + locks_held__ready_to_fork = threading.Event() + fork_happened__release_locks_and_end_thread = threading.Event() + + def lock_holder_thread_fn(): + logging._acquireLock() + try: + refed_h.acquire() + try: + # Tell the main thread to do the fork. + locks_held__ready_to_fork.set() + + # If the deadlock bug exists, the fork will happen + # without dealing with the locks we hold, deadlocking + # the child. + + # Wait for a successful fork or an unreasonable amount of + # time before releasing our locks. To avoid a timing based + # test we'd need communication from os.fork() as to when it + # has actually happened. Given this is a regression test + # for a fixed issue, potentially less reliably detecting + # regression via timing is acceptable for simplicity. + # The test will always take at least this long. 
:( + fork_happened__release_locks_and_end_thread.wait(0.5) + finally: + refed_h.release() + finally: + logging._releaseLock() + + lock_holder_thread = threading.Thread( + target=lock_holder_thread_fn, + name='test_post_fork_child_no_deadlock lock holder') + lock_holder_thread.start() + + locks_held__ready_to_fork.wait() + pid = os.fork() + if pid == 0: + # Child process + try: + test_logger.info(r'Child process did not deadlock. \o/') + finally: + os._exit(0) + else: + # Parent process + test_logger.info(r'Parent process returned from fork. \o/') + fork_happened__release_locks_and_end_thread.set() + lock_holder_thread.join() + + support.wait_process(pid, exitcode=0) + + +class BadStream(object): + def write(self, data): + raise RuntimeError('deliberate mistake') + +class TestStreamHandler(logging.StreamHandler): + def handleError(self, record): + self.error_record = record + +class StreamWithIntName(object): + level = logging.NOTSET + name = 2 + +class StreamHandlerTest(BaseTest): + def test_error_handling(self): + h = TestStreamHandler(BadStream()) + r = logging.makeLogRecord({}) + old_raise = logging.raiseExceptions + + try: + h.handle(r) + self.assertIs(h.error_record, r) + + h = logging.StreamHandler(BadStream()) + with support.captured_stderr() as stderr: + h.handle(r) + msg = '\nRuntimeError: deliberate mistake\n' + self.assertIn(msg, stderr.getvalue()) + + logging.raiseExceptions = False + with support.captured_stderr() as stderr: + h.handle(r) + self.assertEqual('', stderr.getvalue()) + finally: + logging.raiseExceptions = old_raise + + def test_stream_setting(self): + """ + Test setting the handler's stream + """ + h = logging.StreamHandler() + stream = io.StringIO() + old = h.setStream(stream) + self.assertIs(old, sys.stderr) + actual = h.setStream(old) + self.assertIs(actual, stream) + # test that setting to existing value returns None + actual = h.setStream(old) + self.assertIsNone(actual) + + def test_can_represent_stream_with_int_name(self): + h = 
logging.StreamHandler(StreamWithIntName())
+        self.assertEqual(repr(h), '<StreamHandler 2 (NOTSET)>')
+
+# -- The following section could be moved into a server_helper.py module
+# -- if it proves to be of wider utility than just test_logging
+
+class TestSMTPServer(smtpd.SMTPServer):
+    """
+    This class implements a test SMTP server.
+
+    :param addr: A (host, port) tuple which the server listens on.
+                 You can specify a port value of zero: the server's
+                 *port* attribute will hold the actual port number
+                 used, which can be used in client connections.
+    :param handler: A callable which will be called to process
+                    incoming messages. The handler will be passed
+                    the client address tuple, who the message is from,
+                    a list of recipients and the message data.
+    :param poll_interval: The interval, in seconds, used in the underlying
+                          :func:`select` or :func:`poll` call by
+                          :func:`asyncore.loop`.
+    :param sockmap: A dictionary which will be used to hold
+                    :class:`asyncore.dispatcher` instances used by
+                    :func:`asyncore.loop`. This avoids changing the
+                    :mod:`asyncore` module's global state.
+    """
+
+    def __init__(self, addr, handler, poll_interval, sockmap):
+        smtpd.SMTPServer.__init__(self, addr, None, map=sockmap,
+                                  decode_data=True)
+        self.port = self.socket.getsockname()[1]
+        self._handler = handler
+        self._thread = None
+        self._quit = False
+        self.poll_interval = poll_interval
+
+    def process_message(self, peer, mailfrom, rcpttos, data):
+        """
+        Delegates to the handler passed in to the server's constructor.
+
+        Typically, this will be a test case method.
+        :param peer: The client (host, port) tuple.
+        :param mailfrom: The address of the sender.
+        :param rcpttos: The addresses of the recipients.
+        :param data: The message.
+        """
+        self._handler(peer, mailfrom, rcpttos, data)
+
+    def start(self):
+        """
+        Start the server running on a separate daemon thread.
+ """ + self._thread = t = threading.Thread(target=self.serve_forever, + args=(self.poll_interval,)) + t.daemon = True + t.start() + + def serve_forever(self, poll_interval): + """ + Run the :mod:`asyncore` loop until normal termination + conditions arise. + :param poll_interval: The interval, in seconds, used in the underlying + :func:`select` or :func:`poll` call by + :func:`asyncore.loop`. + """ + while not self._quit: + asyncore.loop(poll_interval, map=self._map, count=1) + + def stop(self): + """ + Stop the thread by closing the server instance. + Wait for the server thread to terminate. + """ + self._quit = True + threading_helper.join_thread(self._thread) + self._thread = None + self.close() + asyncore.close_all(map=self._map, ignore_all=True) + + +class ControlMixin(object): + """ + This mixin is used to start a server on a separate thread, and + shut it down programmatically. Request handling is simplified - instead + of needing to derive a suitable RequestHandler subclass, you just + provide a callable which will be passed each received request to be + processed. + + :param handler: A handler callable which will be called with a + single parameter - the request - in order to + process the request. This handler is called on the + server thread, effectively meaning that requests are + processed serially. While not quite web scale ;-), + this should be fine for testing applications. + :param poll_interval: The polling interval in seconds. + """ + def __init__(self, handler, poll_interval): + self._thread = None + self.poll_interval = poll_interval + self._handler = handler + self.ready = threading.Event() + + def start(self): + """ + Create a daemon thread to run the server, and start it. + """ + self._thread = t = threading.Thread(target=self.serve_forever, + args=(self.poll_interval,)) + t.daemon = True + t.start() + + def serve_forever(self, poll_interval): + """ + Run the server. Set the ready flag before entering the + service loop. 
+ """ + self.ready.set() + super(ControlMixin, self).serve_forever(poll_interval) + + def stop(self): + """ + Tell the server thread to stop, and wait for it to do so. + """ + self.shutdown() + if self._thread is not None: + threading_helper.join_thread(self._thread) + self._thread = None + self.server_close() + self.ready.clear() + +class TestHTTPServer(ControlMixin, HTTPServer): + """ + An HTTP server which is controllable using :class:`ControlMixin`. + + :param addr: A tuple with the IP address and port to listen on. + :param handler: A handler callable which will be called with a + single parameter - the request - in order to + process the request. + :param poll_interval: The polling interval in seconds. + :param log: Pass ``True`` to enable log messages. + """ + def __init__(self, addr, handler, poll_interval=0.5, + log=False, sslctx=None): + class DelegatingHTTPRequestHandler(BaseHTTPRequestHandler): + def __getattr__(self, name, default=None): + if name.startswith('do_'): + return self.process_request + raise AttributeError(name) + + def process_request(self): + self.server._handler(self) + + def log_message(self, format, *args): + if log: + super(DelegatingHTTPRequestHandler, + self).log_message(format, *args) + HTTPServer.__init__(self, addr, DelegatingHTTPRequestHandler) + ControlMixin.__init__(self, handler, poll_interval) + self.sslctx = sslctx + + def get_request(self): + try: + sock, addr = self.socket.accept() + if self.sslctx: + sock = self.sslctx.wrap_socket(sock, server_side=True) + except OSError as e: + # socket errors are silenced by the caller, print them here + sys.stderr.write("Got an error:\n%s\n" % e) + raise + return sock, addr + +class TestTCPServer(ControlMixin, ThreadingTCPServer): + """ + A TCP server which is controllable using :class:`ControlMixin`. + + :param addr: A tuple with the IP address and port to listen on. 
+ :param handler: A handler callable which will be called with a single + parameter - the request - in order to process the request. + :param poll_interval: The polling interval in seconds. + :bind_and_activate: If True (the default), binds the server and starts it + listening. If False, you need to call + :meth:`server_bind` and :meth:`server_activate` at + some later time before calling :meth:`start`, so that + the server will set up the socket and listen on it. + """ + + allow_reuse_address = True + + def __init__(self, addr, handler, poll_interval=0.5, + bind_and_activate=True): + class DelegatingTCPRequestHandler(StreamRequestHandler): + + def handle(self): + self.server._handler(self) + ThreadingTCPServer.__init__(self, addr, DelegatingTCPRequestHandler, + bind_and_activate) + ControlMixin.__init__(self, handler, poll_interval) + + def server_bind(self): + super(TestTCPServer, self).server_bind() + self.port = self.socket.getsockname()[1] + +class TestUDPServer(ControlMixin, ThreadingUDPServer): + """ + A UDP server which is controllable using :class:`ControlMixin`. + + :param addr: A tuple with the IP address and port to listen on. + :param handler: A handler callable which will be called with a + single parameter - the request - in order to + process the request. + :param poll_interval: The polling interval for shutdown requests, + in seconds. + :bind_and_activate: If True (the default), binds the server and + starts it listening. If False, you need to + call :meth:`server_bind` and + :meth:`server_activate` at some later time + before calling :meth:`start`, so that the server will + set up the socket and listen on it. 
+ """ + def __init__(self, addr, handler, poll_interval=0.5, + bind_and_activate=True): + class DelegatingUDPRequestHandler(DatagramRequestHandler): + + def handle(self): + self.server._handler(self) + + def finish(self): + data = self.wfile.getvalue() + if data: + try: + super(DelegatingUDPRequestHandler, self).finish() + except OSError: + if not self.server._closed: + raise + + ThreadingUDPServer.__init__(self, addr, + DelegatingUDPRequestHandler, + bind_and_activate) + ControlMixin.__init__(self, handler, poll_interval) + self._closed = False + + def server_bind(self): + super(TestUDPServer, self).server_bind() + self.port = self.socket.getsockname()[1] + + def server_close(self): + super(TestUDPServer, self).server_close() + self._closed = True + +if hasattr(socket, "AF_UNIX"): + class TestUnixStreamServer(TestTCPServer): + address_family = socket.AF_UNIX + + class TestUnixDatagramServer(TestUDPServer): + address_family = socket.AF_UNIX + +# - end of server_helper section + +@support.requires_working_socket() +@threading_helper.requires_working_threading() +class SMTPHandlerTest(BaseTest): + # bpo-14314, bpo-19665, bpo-34092: don't wait forever + TIMEOUT = support.LONG_TIMEOUT + + # TODO: RUSTPYTHON + @unittest.skip(reason="Hangs RustPython") + def test_basic(self): + sockmap = {} + server = TestSMTPServer((socket_helper.HOST, 0), self.process_message, 0.001, + sockmap) + server.start() + addr = (socket_helper.HOST, server.port) + h = logging.handlers.SMTPHandler(addr, 'me', 'you', 'Log', + timeout=self.TIMEOUT) + self.assertEqual(h.toaddrs, ['you']) + self.messages = [] + r = logging.makeLogRecord({'msg': 'Hello \u2713'}) + self.handled = threading.Event() + h.handle(r) + self.handled.wait(self.TIMEOUT) + server.stop() + self.assertTrue(self.handled.is_set()) + self.assertEqual(len(self.messages), 1) + peer, mailfrom, rcpttos, data = self.messages[0] + self.assertEqual(mailfrom, 'me') + self.assertEqual(rcpttos, ['you']) + self.assertIn('\nSubject: Log\n', 
data) + self.assertTrue(data.endswith('\n\nHello \u2713')) + h.close() + + def process_message(self, *args): + self.messages.append(args) + self.handled.set() + +class MemoryHandlerTest(BaseTest): + + """Tests for the MemoryHandler.""" + + # Do not bother with a logger name group. + expected_log_pat = r"^[\w.]+ -> (\w+): (\d+)$" + + def setUp(self): + BaseTest.setUp(self) + self.mem_hdlr = logging.handlers.MemoryHandler(10, logging.WARNING, + self.root_hdlr) + self.mem_logger = logging.getLogger('mem') + self.mem_logger.propagate = 0 + self.mem_logger.addHandler(self.mem_hdlr) + + def tearDown(self): + self.mem_hdlr.close() + BaseTest.tearDown(self) + + def test_flush(self): + # The memory handler flushes to its target handler based on specific + # criteria (message count and message level). + self.mem_logger.debug(self.next_message()) + self.assert_log_lines([]) + self.mem_logger.info(self.next_message()) + self.assert_log_lines([]) + # This will flush because the level is >= logging.WARNING + self.mem_logger.warning(self.next_message()) + lines = [ + ('DEBUG', '1'), + ('INFO', '2'), + ('WARNING', '3'), + ] + self.assert_log_lines(lines) + for n in (4, 14): + for i in range(9): + self.mem_logger.debug(self.next_message()) + self.assert_log_lines(lines) + # This will flush because it's the 10th message since the last + # flush. + self.mem_logger.debug(self.next_message()) + lines = lines + [('DEBUG', str(i)) for i in range(n, n + 10)] + self.assert_log_lines(lines) + + self.mem_logger.debug(self.next_message()) + self.assert_log_lines(lines) + + def test_flush_on_close(self): + """ + Test that the flush-on-close configuration works as expected. + """ + self.mem_logger.debug(self.next_message()) + self.assert_log_lines([]) + self.mem_logger.info(self.next_message()) + self.assert_log_lines([]) + self.mem_logger.removeHandler(self.mem_hdlr) + # Default behaviour is to flush on close. Check that it happens. 
+ self.mem_hdlr.close() + lines = [ + ('DEBUG', '1'), + ('INFO', '2'), + ] + self.assert_log_lines(lines) + # Now configure for flushing not to be done on close. + self.mem_hdlr = logging.handlers.MemoryHandler(10, logging.WARNING, + self.root_hdlr, + False) + self.mem_logger.addHandler(self.mem_hdlr) + self.mem_logger.debug(self.next_message()) + self.assert_log_lines(lines) # no change + self.mem_logger.info(self.next_message()) + self.assert_log_lines(lines) # no change + self.mem_logger.removeHandler(self.mem_hdlr) + self.mem_hdlr.close() + # assert that no new lines have been added + self.assert_log_lines(lines) # no change + + def test_shutdown_flush_on_close(self): + """ + Test that the flush-on-close configuration is respected by the + shutdown method. + """ + self.mem_logger.debug(self.next_message()) + self.assert_log_lines([]) + self.mem_logger.info(self.next_message()) + self.assert_log_lines([]) + # Default behaviour is to flush on close. Check that it happens. + logging.shutdown(handlerList=[logging.weakref.ref(self.mem_hdlr)]) + lines = [ + ('DEBUG', '1'), + ('INFO', '2'), + ] + self.assert_log_lines(lines) + # Now configure for flushing not to be done on close. 
+ self.mem_hdlr = logging.handlers.MemoryHandler(10, logging.WARNING, + self.root_hdlr, + False) + self.mem_logger.addHandler(self.mem_hdlr) + self.mem_logger.debug(self.next_message()) + self.assert_log_lines(lines) # no change + self.mem_logger.info(self.next_message()) + self.assert_log_lines(lines) # no change + # assert that no new lines have been added after shutdown + logging.shutdown(handlerList=[logging.weakref.ref(self.mem_hdlr)]) + self.assert_log_lines(lines) # no change + + @threading_helper.requires_working_threading() + def test_race_between_set_target_and_flush(self): + class MockRaceConditionHandler: + def __init__(self, mem_hdlr): + self.mem_hdlr = mem_hdlr + self.threads = [] + + def removeTarget(self): + self.mem_hdlr.setTarget(None) + + def handle(self, msg): + thread = threading.Thread(target=self.removeTarget) + self.threads.append(thread) + thread.start() + + target = MockRaceConditionHandler(self.mem_hdlr) + try: + self.mem_hdlr.setTarget(target) + + for _ in range(10): + time.sleep(0.005) + self.mem_logger.info("not flushed") + self.mem_logger.warning("flushed") + finally: + for thread in target.threads: + threading_helper.join_thread(thread) + + +class ExceptionFormatter(logging.Formatter): + """A special exception formatter.""" + def formatException(self, ei): + return "Got a [%s]" % ei[0].__name__ + +def closeFileHandler(h, fn): + h.close() + os.remove(fn) + +class ConfigFileTest(BaseTest): + + """Reading logging config from a .ini-style config file.""" + + check_no_resource_warning = warnings_helper.check_no_resource_warning + expected_log_pat = r"^(\w+) \+\+ (\w+)$" + + # config0 is a standard configuration. 
+ config0 = """ + [loggers] + keys=root + + [handlers] + keys=hand1 + + [formatters] + keys=form1 + + [logger_root] + level=WARNING + handlers=hand1 + + [handler_hand1] + class=StreamHandler + level=NOTSET + formatter=form1 + args=(sys.stdout,) + + [formatter_form1] + format=%(levelname)s ++ %(message)s + datefmt= + """ + + # config1 adds a little to the standard configuration. + config1 = """ + [loggers] + keys=root,parser + + [handlers] + keys=hand1 + + [formatters] + keys=form1 + + [logger_root] + level=WARNING + handlers= + + [logger_parser] + level=DEBUG + handlers=hand1 + propagate=1 + qualname=compiler.parser + + [handler_hand1] + class=StreamHandler + level=NOTSET + formatter=form1 + args=(sys.stdout,) + + [formatter_form1] + format=%(levelname)s ++ %(message)s + datefmt= + """ + + # config1a moves the handler to the root. + config1a = """ + [loggers] + keys=root,parser + + [handlers] + keys=hand1 + + [formatters] + keys=form1 + + [logger_root] + level=WARNING + handlers=hand1 + + [logger_parser] + level=DEBUG + handlers= + propagate=1 + qualname=compiler.parser + + [handler_hand1] + class=StreamHandler + level=NOTSET + formatter=form1 + args=(sys.stdout,) + + [formatter_form1] + format=%(levelname)s ++ %(message)s + datefmt= + """ + + # config2 has a subtle configuration error that should be reported + config2 = config1.replace("sys.stdout", "sys.stbout") + + # config3 has a less subtle configuration error + config3 = config1.replace("formatter=form1", "formatter=misspelled_name") + + # config4 specifies a custom formatter class to be loaded + config4 = """ + [loggers] + keys=root + + [handlers] + keys=hand1 + + [formatters] + keys=form1 + + [logger_root] + level=NOTSET + handlers=hand1 + + [handler_hand1] + class=StreamHandler + level=NOTSET + formatter=form1 + args=(sys.stdout,) + + [formatter_form1] + class=""" + __name__ + """.ExceptionFormatter + format=%(levelname)s:%(name)s:%(message)s + datefmt= + """ + + # config5 specifies a custom handler class 
to be loaded + config5 = config1.replace('class=StreamHandler', 'class=logging.StreamHandler') + + # config6 uses ', ' delimiters in the handlers and formatters sections + config6 = """ + [loggers] + keys=root,parser + + [handlers] + keys=hand1, hand2 + + [formatters] + keys=form1, form2 + + [logger_root] + level=WARNING + handlers= + + [logger_parser] + level=DEBUG + handlers=hand1 + propagate=1 + qualname=compiler.parser + + [handler_hand1] + class=StreamHandler + level=NOTSET + formatter=form1 + args=(sys.stdout,) + + [handler_hand2] + class=StreamHandler + level=NOTSET + formatter=form1 + args=(sys.stderr,) + + [formatter_form1] + format=%(levelname)s ++ %(message)s + datefmt= + + [formatter_form2] + format=%(message)s + datefmt= + """ + + # config7 adds a compiler logger, and uses kwargs instead of args. + config7 = """ + [loggers] + keys=root,parser,compiler + + [handlers] + keys=hand1 + + [formatters] + keys=form1 + + [logger_root] + level=WARNING + handlers=hand1 + + [logger_compiler] + level=DEBUG + handlers= + propagate=1 + qualname=compiler + + [logger_parser] + level=DEBUG + handlers= + propagate=1 + qualname=compiler.parser + + [handler_hand1] + class=StreamHandler + level=NOTSET + formatter=form1 + kwargs={'stream': sys.stdout,} + + [formatter_form1] + format=%(levelname)s ++ %(message)s + datefmt= + """ + + # config 8, check for resource warning + config8 = r""" + [loggers] + keys=root + + [handlers] + keys=file + + [formatters] + keys= + + [logger_root] + level=DEBUG + handlers=file + + [handler_file] + class=FileHandler + level=DEBUG + args=("{tempfile}",) + kwargs={{"encoding": "utf-8"}} + """ + + + config9 = """ + [loggers] + keys=root + + [handlers] + keys=hand1 + + [formatters] + keys=form1 + + [logger_root] + level=WARNING + handlers=hand1 + + [handler_hand1] + class=StreamHandler + level=NOTSET + formatter=form1 + args=(sys.stdout,) + + [formatter_form1] + format=%(message)s ++ %(customfield)s + defaults={"customfield": "defaultvalue"} + """ 
+ + disable_test = """ + [loggers] + keys=root + + [handlers] + keys=screen + + [formatters] + keys= + + [logger_root] + level=DEBUG + handlers=screen + + [handler_screen] + level=DEBUG + class=StreamHandler + args=(sys.stdout,) + formatter= + """ + + def apply_config(self, conf, **kwargs): + file = io.StringIO(textwrap.dedent(conf)) + logging.config.fileConfig(file, encoding="utf-8", **kwargs) + + def test_config0_ok(self): + # A simple config file which overrides the default settings. + with support.captured_stdout() as output: + self.apply_config(self.config0) + logger = logging.getLogger() + # Won't output anything + logger.info(self.next_message()) + # Outputs a message + logger.error(self.next_message()) + self.assert_log_lines([ + ('ERROR', '2'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + + def test_config0_using_cp_ok(self): + # A simple config file which overrides the default settings. + with support.captured_stdout() as output: + file = io.StringIO(textwrap.dedent(self.config0)) + cp = configparser.ConfigParser() + cp.read_file(file) + logging.config.fileConfig(cp) + logger = logging.getLogger() + # Won't output anything + logger.info(self.next_message()) + # Outputs a message + logger.error(self.next_message()) + self.assert_log_lines([ + ('ERROR', '2'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + + def test_config1_ok(self, config=config1): + # A config file defining a sub-parser as well. + with support.captured_stdout() as output: + self.apply_config(config) + logger = logging.getLogger("compiler.parser") + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + + def test_config2_failure(self): + # A simple config file which overrides the default settings. 
+ self.assertRaises(Exception, self.apply_config, self.config2) + + def test_config3_failure(self): + # A simple config file which overrides the default settings. + self.assertRaises(Exception, self.apply_config, self.config3) + + def test_config4_ok(self): + # A config file specifying a custom formatter class. + with support.captured_stdout() as output: + self.apply_config(self.config4) + logger = logging.getLogger() + try: + raise RuntimeError() + except RuntimeError: + logging.exception("just testing") + sys.stdout.seek(0) + self.assertEqual(output.getvalue(), + "ERROR:root:just testing\nGot a [RuntimeError]\n") + # Original logger output is empty + self.assert_log_lines([]) + + def test_config5_ok(self): + self.test_config1_ok(config=self.config5) + + def test_config6_ok(self): + self.test_config1_ok(config=self.config6) + + def test_config7_ok(self): + with support.captured_stdout() as output: + self.apply_config(self.config1a) + logger = logging.getLogger("compiler.parser") + # See issue #11424. compiler-hyphenated sorts + # between compiler and compiler.xyz and this + # was preventing compiler.xyz from being included + # in the child loggers of compiler because of an + # overzealous loop termination condition. + hyphenated = logging.getLogger('compiler-hyphenated') + # All will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + hyphenated.critical(self.next_message()) + self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ('CRITICAL', '3'), + ], stream=output) + # Original logger output is empty. 
+ self.assert_log_lines([]) + with support.captured_stdout() as output: + self.apply_config(self.config7) + logger = logging.getLogger("compiler.parser") + self.assertFalse(logger.disabled) + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + logger = logging.getLogger("compiler.lexer") + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + # Will not appear + hyphenated.critical(self.next_message()) + self.assert_log_lines([ + ('INFO', '4'), + ('ERROR', '5'), + ('INFO', '6'), + ('ERROR', '7'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + + def test_config8_ok(self): + + with self.check_no_resource_warning(): + fn = make_temp_file(".log", "test_logging-X-") + + # Replace single backslash with double backslash in windows + # to avoid unicode error during string formatting + if os.name == "nt": + fn = fn.replace("\\", "\\\\") + + config8 = self.config8.format(tempfile=fn) + + self.apply_config(config8) + self.apply_config(config8) + + handler = logging.root.handlers[0] + self.addCleanup(closeFileHandler, handler, fn) + + def test_config9_ok(self): + self.apply_config(self.config9) + formatter = logging.root.handlers[0].formatter + result = formatter.format(logging.makeLogRecord({'msg': 'test'})) + self.assertEqual(result, 'test ++ defaultvalue') + result = formatter.format(logging.makeLogRecord( + {'msg': 'test', 'customfield': "customvalue"})) + self.assertEqual(result, 'test ++ customvalue') + + + def test_logger_disabling(self): + self.apply_config(self.disable_test) + logger = logging.getLogger('some_pristine_logger') + self.assertFalse(logger.disabled) + self.apply_config(self.disable_test) + self.assertTrue(logger.disabled) + self.apply_config(self.disable_test, disable_existing_loggers=False) + self.assertFalse(logger.disabled) + + def test_config_set_handler_names(self): + test_config = """ + [loggers] + keys=root + + 
[handlers] + keys=hand1 + + [formatters] + keys=form1 + + [logger_root] + handlers=hand1 + + [handler_hand1] + class=StreamHandler + formatter=form1 + + [formatter_form1] + format=%(levelname)s ++ %(message)s + """ + self.apply_config(test_config) + self.assertEqual(logging.getLogger().handlers[0].name, 'hand1') + + def test_exception_if_confg_file_is_invalid(self): + test_config = """ + [loggers] + keys=root + + [handlers] + keys=hand1 + + [formatters] + keys=form1 + + [logger_root] + handlers=hand1 + + [handler_hand1] + class=StreamHandler + formatter=form1 + + [formatter_form1] + format=%(levelname)s ++ %(message)s + + prince + """ + + file = io.StringIO(textwrap.dedent(test_config)) + self.assertRaises(RuntimeError, logging.config.fileConfig, file) + + def test_exception_if_confg_file_is_empty(self): + fd, fn = tempfile.mkstemp(prefix='test_empty_', suffix='.ini') + os.close(fd) + self.assertRaises(RuntimeError, logging.config.fileConfig, fn) + os.remove(fn) + + def test_exception_if_config_file_does_not_exist(self): + self.assertRaises(FileNotFoundError, logging.config.fileConfig, 'filenotfound') + + def test_defaults_do_no_interpolation(self): + """bpo-33802 defaults should not get interpolated""" + ini = textwrap.dedent(""" + [formatters] + keys=default + + [formatter_default] + + [handlers] + keys=console + + [handler_console] + class=logging.StreamHandler + args=tuple() + + [loggers] + keys=root + + [logger_root] + formatter=default + handlers=console + """).strip() + fd, fn = tempfile.mkstemp(prefix='test_logging_', suffix='.ini') + try: + os.write(fd, ini.encode('ascii')) + os.close(fd) + logging.config.fileConfig( + fn, + encoding="utf-8", + defaults=dict( + version=1, + disable_existing_loggers=False, + formatters={ + "generic": { + "format": "%(asctime)s [%(process)d] [%(levelname)s] %(message)s", + "datefmt": "[%Y-%m-%d %H:%M:%S %z]", + "class": "logging.Formatter" + }, + }, + ) + ) + finally: + os.unlink(fn) + + +@support.requires_working_socket() 
+@threading_helper.requires_working_threading() +class SocketHandlerTest(BaseTest): + + """Test for SocketHandler objects.""" + + server_class = TestTCPServer + address = ('localhost', 0) + + def setUp(self): + """Set up a TCP server to receive log messages, and a SocketHandler + pointing to that server's address and port.""" + BaseTest.setUp(self) + # Issue #29177: deal with errors that happen during setup + self.server = self.sock_hdlr = self.server_exception = None + try: + self.server = server = self.server_class(self.address, + self.handle_socket, 0.01) + server.start() + # Uncomment next line to test error recovery in setUp() + # raise OSError('dummy error raised') + except OSError as e: + self.server_exception = e + return + server.ready.wait() + hcls = logging.handlers.SocketHandler + if isinstance(server.server_address, tuple): + self.sock_hdlr = hcls('localhost', server.port) + else: + self.sock_hdlr = hcls(server.server_address, None) + self.log_output = '' + self.root_logger.removeHandler(self.root_logger.handlers[0]) + self.root_logger.addHandler(self.sock_hdlr) + self.handled = threading.Semaphore(0) + + def tearDown(self): + """Shutdown the TCP server.""" + try: + if self.sock_hdlr: + self.root_logger.removeHandler(self.sock_hdlr) + self.sock_hdlr.close() + if self.server: + self.server.stop() + finally: + BaseTest.tearDown(self) + + def handle_socket(self, request): + conn = request.connection + while True: + chunk = conn.recv(4) + if len(chunk) < 4: + break + slen = struct.unpack(">L", chunk)[0] + chunk = conn.recv(slen) + while len(chunk) < slen: + chunk = chunk + conn.recv(slen - len(chunk)) + obj = pickle.loads(chunk) + record = logging.makeLogRecord(obj) + self.log_output += record.msg + '\n' + self.handled.release() + + def test_output(self): + # The log message sent to the SocketHandler is properly received. 
+ if self.server_exception: + self.skipTest(self.server_exception) + logger = logging.getLogger("tcp") + logger.error("spam") + self.handled.acquire() + logger.debug("eggs") + self.handled.acquire() + self.assertEqual(self.log_output, "spam\neggs\n") + + def test_noserver(self): + if self.server_exception: + self.skipTest(self.server_exception) + # Avoid timing-related failures due to SocketHandler's own hard-wired + # one-second timeout on socket.create_connection() (issue #16264). + self.sock_hdlr.retryStart = 2.5 + # Kill the server + self.server.stop() + # The logging call should try to connect, which should fail + try: + raise RuntimeError('Deliberate mistake') + except RuntimeError: + self.root_logger.exception('Never sent') + self.root_logger.error('Never sent, either') + now = time.time() + self.assertGreater(self.sock_hdlr.retryTime, now) + time.sleep(self.sock_hdlr.retryTime - now + 0.001) + self.root_logger.error('Nor this') + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "Unix sockets required") +class UnixSocketHandlerTest(SocketHandlerTest): + + """Test for SocketHandler with unix sockets.""" + + if hasattr(socket, "AF_UNIX"): + server_class = TestUnixStreamServer + + def setUp(self): + # override the definition in the base class + self.address = socket_helper.create_unix_domain_name() + self.addCleanup(os_helper.unlink, self.address) + SocketHandlerTest.setUp(self) + +@support.requires_working_socket() +@threading_helper.requires_working_threading() +class DatagramHandlerTest(BaseTest): + + """Test for DatagramHandler.""" + + server_class = TestUDPServer + address = ('localhost', 0) + + def setUp(self): + """Set up a UDP server to receive log messages, and a DatagramHandler + pointing to that server's address and port.""" + BaseTest.setUp(self) + # Issue #29177: deal with errors that happen during setup + self.server = self.sock_hdlr = self.server_exception = None + try: + self.server = server = self.server_class(self.address, + 
self.handle_datagram, 0.01) + server.start() + # Uncomment next line to test error recovery in setUp() + # raise OSError('dummy error raised') + except OSError as e: + self.server_exception = e + return + server.ready.wait() + hcls = logging.handlers.DatagramHandler + if isinstance(server.server_address, tuple): + self.sock_hdlr = hcls('localhost', server.port) + else: + self.sock_hdlr = hcls(server.server_address, None) + self.log_output = '' + self.root_logger.removeHandler(self.root_logger.handlers[0]) + self.root_logger.addHandler(self.sock_hdlr) + self.handled = threading.Event() + + def tearDown(self): + """Shutdown the UDP server.""" + try: + if self.server: + self.server.stop() + if self.sock_hdlr: + self.root_logger.removeHandler(self.sock_hdlr) + self.sock_hdlr.close() + finally: + BaseTest.tearDown(self) + + def handle_datagram(self, request): + slen = struct.pack('>L', 0) # length of prefix + packet = request.packet[len(slen):] + obj = pickle.loads(packet) + record = logging.makeLogRecord(obj) + self.log_output += record.msg + '\n' + self.handled.set() + + def test_output(self): + # The log message sent to the DatagramHandler is properly received. 
+ if self.server_exception: + self.skipTest(self.server_exception) + logger = logging.getLogger("udp") + logger.error("spam") + self.handled.wait() + self.handled.clear() + logger.error("eggs") + self.handled.wait() + self.assertEqual(self.log_output, "spam\neggs\n") + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "Unix sockets required") +class UnixDatagramHandlerTest(DatagramHandlerTest): + + """Test for DatagramHandler using Unix sockets.""" + + if hasattr(socket, "AF_UNIX"): + server_class = TestUnixDatagramServer + + def setUp(self): + # override the definition in the base class + self.address = socket_helper.create_unix_domain_name() + self.addCleanup(os_helper.unlink, self.address) + DatagramHandlerTest.setUp(self) + +@support.requires_working_socket() +@threading_helper.requires_working_threading() +class SysLogHandlerTest(BaseTest): + + """Test for SysLogHandler using UDP.""" + + server_class = TestUDPServer + address = ('localhost', 0) + + def setUp(self): + """Set up a UDP server to receive log messages, and a SysLogHandler + pointing to that server's address and port.""" + BaseTest.setUp(self) + # Issue #29177: deal with errors that happen during setup + self.server = self.sl_hdlr = self.server_exception = None + try: + self.server = server = self.server_class(self.address, + self.handle_datagram, 0.01) + server.start() + # Uncomment next line to test error recovery in setUp() + # raise OSError('dummy error raised') + except OSError as e: + self.server_exception = e + return + server.ready.wait() + hcls = logging.handlers.SysLogHandler + if isinstance(server.server_address, tuple): + self.sl_hdlr = hcls((server.server_address[0], server.port)) + else: + self.sl_hdlr = hcls(server.server_address) + self.log_output = b'' + self.root_logger.removeHandler(self.root_logger.handlers[0]) + self.root_logger.addHandler(self.sl_hdlr) + self.handled = threading.Event() + + def tearDown(self): + """Shutdown the server.""" + try: + if self.server: + 
self.server.stop() + if self.sl_hdlr: + self.root_logger.removeHandler(self.sl_hdlr) + self.sl_hdlr.close() + finally: + BaseTest.tearDown(self) + + def handle_datagram(self, request): + self.log_output = request.packet + self.handled.set() + + def test_output(self): + if self.server_exception: + self.skipTest(self.server_exception) + # The log message sent to the SysLogHandler is properly received. + logger = logging.getLogger("slh") + logger.error("sp\xe4m") + self.handled.wait(support.LONG_TIMEOUT) + self.assertEqual(self.log_output, b'<11>sp\xc3\xa4m\x00') + self.handled.clear() + self.sl_hdlr.append_nul = False + logger.error("sp\xe4m") + self.handled.wait(support.LONG_TIMEOUT) + self.assertEqual(self.log_output, b'<11>sp\xc3\xa4m') + self.handled.clear() + self.sl_hdlr.ident = "h\xe4m-" + logger.error("sp\xe4m") + self.handled.wait(support.LONG_TIMEOUT) + self.assertEqual(self.log_output, b'<11>h\xc3\xa4m-sp\xc3\xa4m') + + def test_udp_reconnection(self): + logger = logging.getLogger("slh") + self.sl_hdlr.close() + self.handled.clear() + logger.error("sp\xe4m") + self.handled.wait(support.LONG_TIMEOUT) + self.assertEqual(self.log_output, b'<11>sp\xc3\xa4m\x00') + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "Unix sockets required") +class UnixSysLogHandlerTest(SysLogHandlerTest): + + """Test for SysLogHandler with Unix sockets.""" + + if hasattr(socket, "AF_UNIX"): + server_class = TestUnixDatagramServer + + def setUp(self): + # override the definition in the base class + self.address = socket_helper.create_unix_domain_name() + self.addCleanup(os_helper.unlink, self.address) + SysLogHandlerTest.setUp(self) + +@unittest.skipUnless(socket_helper.IPV6_ENABLED, + 'IPv6 support required for this test.') +class IPv6SysLogHandlerTest(SysLogHandlerTest): + + """Test for SysLogHandler with IPv6 host.""" + + server_class = TestUDPServer + address = ('::1', 0) + + def setUp(self): + self.server_class.address_family = socket.AF_INET6 + super(IPv6SysLogHandlerTest, 
self).setUp() + + def tearDown(self): + self.server_class.address_family = socket.AF_INET + super(IPv6SysLogHandlerTest, self).tearDown() + +@support.requires_working_socket() +@threading_helper.requires_working_threading() +class HTTPHandlerTest(BaseTest): + """Test for HTTPHandler.""" + + def setUp(self): + """Set up an HTTP server to receive log messages, and a HTTPHandler + pointing to that server's address and port.""" + BaseTest.setUp(self) + self.handled = threading.Event() + + def handle_request(self, request): + self.command = request.command + self.log_data = urlparse(request.path) + if self.command == 'POST': + try: + rlen = int(request.headers['Content-Length']) + self.post_data = request.rfile.read(rlen) + except: + self.post_data = None + request.send_response(200) + request.end_headers() + self.handled.set() + + # TODO: RUSTPYTHON + @unittest.skip("RUSTPYTHON") + def test_output(self): + # The log message sent to the HTTPHandler is properly received. + logger = logging.getLogger("http") + root_logger = self.root_logger + root_logger.removeHandler(self.root_logger.handlers[0]) + for secure in (False, True): + addr = ('localhost', 0) + if secure: + try: + import ssl + except ImportError: + sslctx = None + else: + here = os.path.dirname(__file__) + localhost_cert = os.path.join(here, "certdata", "keycert.pem") + sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + sslctx.load_cert_chain(localhost_cert) + + context = ssl.create_default_context(cafile=localhost_cert) + else: + sslctx = None + context = None + self.server = server = TestHTTPServer(addr, self.handle_request, + 0.01, sslctx=sslctx) + server.start() + server.ready.wait() + host = 'localhost:%d' % server.server_port + secure_client = secure and sslctx + self.h_hdlr = logging.handlers.HTTPHandler(host, '/frob', + secure=secure_client, + context=context, + credentials=('foo', 'bar')) + self.log_data = None + root_logger.addHandler(self.h_hdlr) + + for method in ('GET', 'POST'): + 
self.h_hdlr.method = method + self.handled.clear() + msg = "sp\xe4m" + logger.error(msg) + self.handled.wait() + self.assertEqual(self.log_data.path, '/frob') + self.assertEqual(self.command, method) + if method == 'GET': + d = parse_qs(self.log_data.query) + else: + d = parse_qs(self.post_data.decode('utf-8')) + self.assertEqual(d['name'], ['http']) + self.assertEqual(d['funcName'], ['test_output']) + self.assertEqual(d['msg'], [msg]) + + self.server.stop() + self.root_logger.removeHandler(self.h_hdlr) + self.h_hdlr.close() + +class MemoryTest(BaseTest): + + """Test memory persistence of logger objects.""" + + def setUp(self): + """Create a dict to remember potentially destroyed objects.""" + BaseTest.setUp(self) + self._survivors = {} + + def _watch_for_survival(self, *args): + """Watch the given objects for survival, by creating weakrefs to + them.""" + for obj in args: + key = id(obj), repr(obj) + self._survivors[key] = weakref.ref(obj) + + def _assertTruesurvival(self): + """Assert that all objects watched for survival have survived.""" + # Trigger cycle breaking. + gc.collect() + dead = [] + for (id_, repr_), ref in self._survivors.items(): + if ref() is None: + dead.append(repr_) + if dead: + self.fail("%d objects should have survived " + "but have been destroyed: %s" % (len(dead), ", ".join(dead))) + + def test_persistent_loggers(self): + # Logger objects are persistent and retain their configuration, even + # if visible references are destroyed. + self.root_logger.setLevel(logging.INFO) + foo = logging.getLogger("foo") + self._watch_for_survival(foo) + foo.setLevel(logging.DEBUG) + self.root_logger.debug(self.next_message()) + foo.debug(self.next_message()) + self.assert_log_lines([ + ('foo', 'DEBUG', '2'), + ]) + del foo + # foo has survived. + self._assertTruesurvival() + # foo has retained its settings. 
+ bar = logging.getLogger("foo") + bar.debug(self.next_message()) + self.assert_log_lines([ + ('foo', 'DEBUG', '2'), + ('foo', 'DEBUG', '3'), + ]) + + +class EncodingTest(BaseTest): + def test_encoding_plain_file(self): + # In Python 2.x, a plain file object is treated as having no encoding. + log = logging.getLogger("test") + fn = make_temp_file(".log", "test_logging-1-") + # the non-ascii data we write to the log. + data = "foo\x80" + try: + handler = logging.FileHandler(fn, encoding="utf-8") + log.addHandler(handler) + try: + # write non-ascii data to the log. + log.warning(data) + finally: + log.removeHandler(handler) + handler.close() + # check we wrote exactly those bytes, ignoring trailing \n etc + f = open(fn, encoding="utf-8") + try: + self.assertEqual(f.read().rstrip(), data) + finally: + f.close() + finally: + if os.path.isfile(fn): + os.remove(fn) + + def test_encoding_cyrillic_unicode(self): + log = logging.getLogger("test") + # Get a message in Unicode: Do svidanya in Cyrillic (meaning goodbye) + message = '\u0434\u043e \u0441\u0432\u0438\u0434\u0430\u043d\u0438\u044f' + # Ensure it's written in a Cyrillic encoding + writer_class = codecs.getwriter('cp1251') + writer_class.encoding = 'cp1251' + stream = io.BytesIO() + writer = writer_class(stream, 'strict') + handler = logging.StreamHandler(writer) + log.addHandler(handler) + try: + log.warning(message) + finally: + log.removeHandler(handler) + handler.close() + # check we wrote exactly those bytes, ignoring trailing \n etc + s = stream.getvalue() + # Compare against what the data should be when encoded in CP-1251 + self.assertEqual(s, b'\xe4\xee \xf1\xe2\xe8\xe4\xe0\xed\xe8\xff\n') + + +class WarningsTest(BaseTest): + + def test_warnings(self): + with warnings.catch_warnings(): + logging.captureWarnings(True) + self.addCleanup(logging.captureWarnings, False) + warnings.filterwarnings("always", category=UserWarning) + stream = io.StringIO() + h = logging.StreamHandler(stream) + logger = 
logging.getLogger("py.warnings") + logger.addHandler(h) + warnings.warn("I'm warning you...") + logger.removeHandler(h) + s = stream.getvalue() + h.close() + self.assertGreater(s.find("UserWarning: I'm warning you...\n"), 0) + + # See if an explicit file uses the original implementation + a_file = io.StringIO() + warnings.showwarning("Explicit", UserWarning, "dummy.py", 42, + a_file, "Dummy line") + s = a_file.getvalue() + a_file.close() + self.assertEqual(s, + "dummy.py:42: UserWarning: Explicit\n Dummy line\n") + + def test_warnings_no_handlers(self): + with warnings.catch_warnings(): + logging.captureWarnings(True) + self.addCleanup(logging.captureWarnings, False) + + # confirm our assumption: no loggers are set + logger = logging.getLogger("py.warnings") + self.assertEqual(logger.handlers, []) + + warnings.showwarning("Explicit", UserWarning, "dummy.py", 42) + self.assertEqual(len(logger.handlers), 1) + self.assertIsInstance(logger.handlers[0], logging.NullHandler) + + +def formatFunc(format, datefmt=None): + return logging.Formatter(format, datefmt) + +class myCustomFormatter: + def __init__(self, fmt, datefmt=None): + pass + +def handlerFunc(): + return logging.StreamHandler() + +class CustomHandler(logging.StreamHandler): + pass + +class CustomListener(logging.handlers.QueueListener): + pass + +class CustomQueue(queue.Queue): + pass + +class CustomQueueProtocol: + def __init__(self, maxsize=0): + self.queue = queue.Queue(maxsize) + + def __getattr__(self, attribute): + queue = object.__getattribute__(self, 'queue') + return getattr(queue, attribute) + +class CustomQueueFakeProtocol(CustomQueueProtocol): + # An object implementing the minimial Queue API for + # the logging module but with incorrect signatures. + # + # The object will be considered a valid queue class since we + # do not check the signatures (only callability of methods) + # but will NOT be usable in production since a TypeError will + # be raised due to the extra argument in 'put_nowait'. 
+ def put_nowait(self): + pass + +class CustomQueueWrongProtocol(CustomQueueProtocol): + put_nowait = None + +class MinimalQueueProtocol: + def put_nowait(self, x): pass + def get(self): pass + +def queueMaker(): + return queue.Queue() + +def listenerMaker(arg1, arg2, respect_handler_level=False): + def func(queue, *handlers, **kwargs): + kwargs.setdefault('respect_handler_level', respect_handler_level) + return CustomListener(queue, *handlers, **kwargs) + return func + +class ConfigDictTest(BaseTest): + + """Reading logging config from a dictionary.""" + + check_no_resource_warning = warnings_helper.check_no_resource_warning + expected_log_pat = r"^(\w+) \+\+ (\w+)$" + + # config0 is a standard configuration. + config0 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'root' : { + 'level' : 'WARNING', + 'handlers' : ['hand1'], + }, + } + + # config1 adds a little to the standard configuration. + config1 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + # config1a moves the handler to the root. 
Used with config8a + config1a = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + }, + }, + 'root' : { + 'level' : 'WARNING', + 'handlers' : ['hand1'], + }, + } + + # config2 has a subtle configuration error that should be reported + config2 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdbout', + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + # As config1 but with a misspelt level on a handler + config2a = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NTOSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + + # As config1 but with a misspelt level on a logger + config2b = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'WRANING', + }, + } + + # config3 has a less subtle configuration error + config3 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' 
: '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'misspelled_name', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + # config4 specifies a custom formatter class to be loaded + config4 = { + 'version': 1, + 'formatters': { + 'form1' : { + '()' : __name__ + '.ExceptionFormatter', + 'format' : '%(levelname)s:%(name)s:%(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'root' : { + 'level' : 'NOTSET', + 'handlers' : ['hand1'], + }, + } + + # As config4 but using an actual callable rather than a string + config4a = { + 'version': 1, + 'formatters': { + 'form1' : { + '()' : ExceptionFormatter, + 'format' : '%(levelname)s:%(name)s:%(message)s', + }, + 'form2' : { + '()' : __name__ + '.formatFunc', + 'format' : '%(levelname)s:%(name)s:%(message)s', + }, + 'form3' : { + '()' : formatFunc, + 'format' : '%(levelname)s:%(name)s:%(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + 'hand2' : { + '()' : handlerFunc, + }, + }, + 'root' : { + 'level' : 'NOTSET', + 'handlers' : ['hand1'], + }, + } + + # config5 specifies a custom handler class to be loaded + config5 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : __name__ + '.CustomHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + # config6 specifies 
a custom handler class to be loaded + # but has bad arguments + config6 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : __name__ + '.CustomHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + '9' : 'invalid parameter name', + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + # config 7 does not define compiler.parser but defines compiler.lexer + # so compiler.parser should be disabled after applying it + config7 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'loggers' : { + 'compiler.lexer' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + # config8 defines both compiler and compiler.lexer + # so compiler.parser should not be disabled (since + # compiler is defined) + config8 = { + 'version': 1, + 'disable_existing_loggers' : False, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'loggers' : { + 'compiler' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + 'compiler.lexer' : { + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + # config8a disables existing loggers + config8a = { + 'version': 1, + 'disable_existing_loggers' : True, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 
'ext://sys.stdout', + }, + }, + 'loggers' : { + 'compiler' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + 'compiler.lexer' : { + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + config9 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'WARNING', + 'stream' : 'ext://sys.stdout', + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'WARNING', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'NOTSET', + }, + } + + config9a = { + 'version': 1, + 'incremental' : True, + 'handlers' : { + 'hand1' : { + 'level' : 'WARNING', + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'INFO', + }, + }, + } + + config9b = { + 'version': 1, + 'incremental' : True, + 'handlers' : { + 'hand1' : { + 'level' : 'INFO', + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'INFO', + }, + }, + } + + # As config1 but with a filter added + config10 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'filters' : { + 'filt1' : { + 'name' : 'compiler.parser', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + 'filters' : ['filt1'], + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + 'filters' : ['filt1'], + }, + }, + 'root' : { + 'level' : 'WARNING', + 'handlers' : ['hand1'], + }, + } + + # As config1 but using cfg:// references + config11 = { + 'version': 1, + 'true_formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handler_configs': { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'formatters' : 'cfg://true_formatters', + 'handlers' : { + 'hand1' : 
'cfg://handler_configs[hand1]', + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + # As config11 but missing the version key + config12 = { + 'true_formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handler_configs': { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'formatters' : 'cfg://true_formatters', + 'handlers' : { + 'hand1' : 'cfg://handler_configs[hand1]', + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + # As config11 but using an unsupported version + config13 = { + 'version': 2, + 'true_formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handler_configs': { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'formatters' : 'cfg://true_formatters', + 'handlers' : { + 'hand1' : 'cfg://handler_configs[hand1]', + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + # As config0, but with properties + config14 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + '.': { + 'foo': 'bar', + 'terminator': '!\n', + } + }, + }, + 'root' : { + 'level' : 'WARNING', + 'handlers' : ['hand1'], + }, + } + + # config0 but with default values for formatter. Skipped 15, it is defined + # in the test code. 
+ config16 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(message)s ++ %(customfield)s', + 'defaults': {"customfield": "defaultvalue"} + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'root' : { + 'level' : 'WARNING', + 'handlers' : ['hand1'], + }, + } + + class CustomFormatter(logging.Formatter): + custom_property = "." + + def format(self, record): + return super().format(record) + + config17 = { + 'version': 1, + 'formatters': { + "custom": { + "()": CustomFormatter, + "style": "{", + "datefmt": "%Y-%m-%d %H:%M:%S", + "format": "{message}", # <-- to force an exception when configuring + ".": { + "custom_property": "value" + } + } + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'custom', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'root' : { + 'level' : 'WARNING', + 'handlers' : ['hand1'], + }, + } + + config18 = { + "version": 1, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "level": "DEBUG", + }, + "buffering": { + "class": "logging.handlers.MemoryHandler", + "capacity": 5, + "target": "console", + "level": "DEBUG", + "flushLevel": "ERROR" + } + }, + "loggers": { + "mymodule": { + "level": "DEBUG", + "handlers": ["buffering"], + "propagate": "true" + } + } + } + + bad_format = { + "version": 1, + "formatters": { + "mySimpleFormatter": { + "format": "%(asctime)s (%(name)s) %(levelname)s: %(message)s", + "style": "$" + } + }, + "handlers": { + "fileGlobal": { + "class": "logging.StreamHandler", + "level": "DEBUG", + "formatter": "mySimpleFormatter" + }, + "bufferGlobal": { + "class": "logging.handlers.MemoryHandler", + "capacity": 5, + "formatter": "mySimpleFormatter", + "target": "fileGlobal", + "level": "DEBUG" + } + }, + "loggers": { + "mymodule": { + "level": "DEBUG", + "handlers": ["bufferGlobal"], + "propagate": "true" + } + } + } + 
+ # Configuration with custom logging.Formatter subclass as '()' key and 'validate' set to False + custom_formatter_class_validate = { + 'version': 1, + 'formatters': { + 'form1': { + '()': __name__ + '.ExceptionFormatter', + 'format': '%(levelname)s:%(name)s:%(message)s', + 'validate': False, + }, + }, + 'handlers' : { + 'hand1' : { + 'class': 'logging.StreamHandler', + 'formatter': 'form1', + 'level': 'NOTSET', + 'stream': 'ext://sys.stdout', + }, + }, + "loggers": { + "my_test_logger_custom_formatter": { + "level": "DEBUG", + "handlers": ["hand1"], + "propagate": "true" + } + } + } + + # Configuration with custom logging.Formatter subclass as 'class' key and 'validate' set to False + custom_formatter_class_validate2 = { + 'version': 1, + 'formatters': { + 'form1': { + 'class': __name__ + '.ExceptionFormatter', + 'format': '%(levelname)s:%(name)s:%(message)s', + 'validate': False, + }, + }, + 'handlers' : { + 'hand1' : { + 'class': 'logging.StreamHandler', + 'formatter': 'form1', + 'level': 'NOTSET', + 'stream': 'ext://sys.stdout', + }, + }, + "loggers": { + "my_test_logger_custom_formatter": { + "level": "DEBUG", + "handlers": ["hand1"], + "propagate": "true" + } + } + } + + # Configuration with custom class that is not inherited from logging.Formatter + custom_formatter_class_validate3 = { + 'version': 1, + 'formatters': { + 'form1': { + 'class': __name__ + '.myCustomFormatter', + 'format': '%(levelname)s:%(name)s:%(message)s', + 'validate': False, + }, + }, + 'handlers' : { + 'hand1' : { + 'class': 'logging.StreamHandler', + 'formatter': 'form1', + 'level': 'NOTSET', + 'stream': 'ext://sys.stdout', + }, + }, + "loggers": { + "my_test_logger_custom_formatter": { + "level": "DEBUG", + "handlers": ["hand1"], + "propagate": "true" + } + } + } + + # Configuration with custom function, 'validate' set to False and no defaults + custom_formatter_with_function = { + 'version': 1, + 'formatters': { + 'form1': { + '()': formatFunc, + 'format': 
'%(levelname)s:%(name)s:%(message)s', + 'validate': False, + }, + }, + 'handlers' : { + 'hand1' : { + 'class': 'logging.StreamHandler', + 'formatter': 'form1', + 'level': 'NOTSET', + 'stream': 'ext://sys.stdout', + }, + }, + "loggers": { + "my_test_logger_custom_formatter": { + "level": "DEBUG", + "handlers": ["hand1"], + "propagate": "true" + } + } + } + + # Configuration with custom function, and defaults + custom_formatter_with_defaults = { + 'version': 1, + 'formatters': { + 'form1': { + '()': formatFunc, + 'format': '%(levelname)s:%(name)s:%(message)s:%(customfield)s', + 'defaults': {"customfield": "myvalue"} + }, + }, + 'handlers' : { + 'hand1' : { + 'class': 'logging.StreamHandler', + 'formatter': 'form1', + 'level': 'NOTSET', + 'stream': 'ext://sys.stdout', + }, + }, + "loggers": { + "my_test_logger_custom_formatter": { + "level": "DEBUG", + "handlers": ["hand1"], + "propagate": "true" + } + } + } + + config_queue_handler = { + 'version': 1, + 'handlers' : { + 'h1' : { + 'class': 'logging.FileHandler', + }, + # key is before depended on handlers to test that deferred config works + 'ah' : { + 'class': 'logging.handlers.QueueHandler', + 'handlers': ['h1'] + }, + }, + "root": { + "level": "DEBUG", + "handlers": ["ah"] + } + } + + def apply_config(self, conf): + logging.config.dictConfig(conf) + + def check_handler(self, name, cls): + h = logging.getHandlerByName(name) + self.assertIsInstance(h, cls) + + def test_config0_ok(self): + # A simple config which overrides the default settings. + with support.captured_stdout() as output: + self.apply_config(self.config0) + self.check_handler('hand1', logging.StreamHandler) + logger = logging.getLogger() + # Won't output anything + logger.info(self.next_message()) + # Outputs a message + logger.error(self.next_message()) + self.assert_log_lines([ + ('ERROR', '2'), + ], stream=output) + # Original logger output is empty. 
+ self.assert_log_lines([]) + + def test_config1_ok(self, config=config1): + # A config defining a sub-parser as well. + with support.captured_stdout() as output: + self.apply_config(config) + logger = logging.getLogger("compiler.parser") + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + + def test_config2_failure(self): + # A simple config which overrides the default settings. + self.assertRaises(Exception, self.apply_config, self.config2) + + def test_config2a_failure(self): + # A simple config which overrides the default settings. + self.assertRaises(Exception, self.apply_config, self.config2a) + + def test_config2b_failure(self): + # A simple config which overrides the default settings. + self.assertRaises(Exception, self.apply_config, self.config2b) + + def test_config3_failure(self): + # A simple config which overrides the default settings. + self.assertRaises(Exception, self.apply_config, self.config3) + + def test_config4_ok(self): + # A config specifying a custom formatter class. + with support.captured_stdout() as output: + self.apply_config(self.config4) + self.check_handler('hand1', logging.StreamHandler) + #logger = logging.getLogger() + try: + raise RuntimeError() + except RuntimeError: + logging.exception("just testing") + sys.stdout.seek(0) + self.assertEqual(output.getvalue(), + "ERROR:root:just testing\nGot a [RuntimeError]\n") + # Original logger output is empty + self.assert_log_lines([]) + + def test_config4a_ok(self): + # A config specifying a custom formatter class. 
+ with support.captured_stdout() as output: + self.apply_config(self.config4a) + #logger = logging.getLogger() + try: + raise RuntimeError() + except RuntimeError: + logging.exception("just testing") + sys.stdout.seek(0) + self.assertEqual(output.getvalue(), + "ERROR:root:just testing\nGot a [RuntimeError]\n") + # Original logger output is empty + self.assert_log_lines([]) + + def test_config5_ok(self): + self.test_config1_ok(config=self.config5) + self.check_handler('hand1', CustomHandler) + + def test_config6_failure(self): + self.assertRaises(Exception, self.apply_config, self.config6) + + def test_config7_ok(self): + with support.captured_stdout() as output: + self.apply_config(self.config1) + logger = logging.getLogger("compiler.parser") + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + with support.captured_stdout() as output: + self.apply_config(self.config7) + self.check_handler('hand1', logging.StreamHandler) + logger = logging.getLogger("compiler.parser") + self.assertTrue(logger.disabled) + logger = logging.getLogger("compiler.lexer") + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([ + ('INFO', '3'), + ('ERROR', '4'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + + # Same as test_config_7_ok but don't disable old loggers. + def test_config_8_ok(self): + with support.captured_stdout() as output: + self.apply_config(self.config1) + logger = logging.getLogger("compiler.parser") + # All will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ], stream=output) + # Original logger output is empty. 
+ self.assert_log_lines([]) + with support.captured_stdout() as output: + self.apply_config(self.config8) + self.check_handler('hand1', logging.StreamHandler) + logger = logging.getLogger("compiler.parser") + self.assertFalse(logger.disabled) + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + logger = logging.getLogger("compiler.lexer") + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([ + ('INFO', '3'), + ('ERROR', '4'), + ('INFO', '5'), + ('ERROR', '6'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + + def test_config_8a_ok(self): + with support.captured_stdout() as output: + self.apply_config(self.config1a) + self.check_handler('hand1', logging.StreamHandler) + logger = logging.getLogger("compiler.parser") + # See issue #11424. compiler-hyphenated sorts + # between compiler and compiler.xyz and this + # was preventing compiler.xyz from being included + # in the child loggers of compiler because of an + # overzealous loop termination condition. + hyphenated = logging.getLogger('compiler-hyphenated') + # All will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + hyphenated.critical(self.next_message()) + self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ('CRITICAL', '3'), + ], stream=output) + # Original logger output is empty. 
+ self.assert_log_lines([]) + with support.captured_stdout() as output: + self.apply_config(self.config8a) + self.check_handler('hand1', logging.StreamHandler) + logger = logging.getLogger("compiler.parser") + self.assertFalse(logger.disabled) + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + logger = logging.getLogger("compiler.lexer") + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + # Will not appear + hyphenated.critical(self.next_message()) + self.assert_log_lines([ + ('INFO', '4'), + ('ERROR', '5'), + ('INFO', '6'), + ('ERROR', '7'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + + def test_config_9_ok(self): + with support.captured_stdout() as output: + self.apply_config(self.config9) + self.check_handler('hand1', logging.StreamHandler) + logger = logging.getLogger("compiler.parser") + # Nothing will be output since both handler and logger are set to WARNING + logger.info(self.next_message()) + self.assert_log_lines([], stream=output) + self.apply_config(self.config9a) + # Nothing will be output since handler is still set to WARNING + logger.info(self.next_message()) + self.assert_log_lines([], stream=output) + self.apply_config(self.config9b) + # Message should now be output + logger.info(self.next_message()) + self.assert_log_lines([ + ('INFO', '3'), + ], stream=output) + + def test_config_10_ok(self): + with support.captured_stdout() as output: + self.apply_config(self.config10) + self.check_handler('hand1', logging.StreamHandler) + logger = logging.getLogger("compiler.parser") + logger.warning(self.next_message()) + logger = logging.getLogger('compiler') + # Not output, because filtered + logger.warning(self.next_message()) + logger = logging.getLogger('compiler.lexer') + # Not output, because filtered + logger.warning(self.next_message()) + logger = logging.getLogger("compiler.parser.codegen") + # Output, as 
not filtered + logger.error(self.next_message()) + self.assert_log_lines([ + ('WARNING', '1'), + ('ERROR', '4'), + ], stream=output) + + def test_config11_ok(self): + self.test_config1_ok(self.config11) + + def test_config12_failure(self): + self.assertRaises(Exception, self.apply_config, self.config12) + + def test_config13_failure(self): + self.assertRaises(Exception, self.apply_config, self.config13) + + def test_config14_ok(self): + with support.captured_stdout() as output: + self.apply_config(self.config14) + h = logging._handlers['hand1'] + self.assertEqual(h.foo, 'bar') + self.assertEqual(h.terminator, '!\n') + logging.warning('Exclamation') + self.assertTrue(output.getvalue().endswith('Exclamation!\n')) + + def test_config15_ok(self): + + with self.check_no_resource_warning(): + fn = make_temp_file(".log", "test_logging-X-") + + config = { + "version": 1, + "handlers": { + "file": { + "class": "logging.FileHandler", + "filename": fn, + "encoding": "utf-8", + } + }, + "root": { + "handlers": ["file"] + } + } + + self.apply_config(config) + self.apply_config(config) + + handler = logging.root.handlers[0] + self.addCleanup(closeFileHandler, handler, fn) + + def test_config16_ok(self): + self.apply_config(self.config16) + h = logging._handlers['hand1'] + + # Custom value + result = h.formatter.format(logging.makeLogRecord( + {'msg': 'Hello', 'customfield': 'customvalue'})) + self.assertEqual(result, 'Hello ++ customvalue') + + # Default value + result = h.formatter.format(logging.makeLogRecord( + {'msg': 'Hello'})) + self.assertEqual(result, 'Hello ++ defaultvalue') + + def test_config17_ok(self): + self.apply_config(self.config17) + h = logging._handlers['hand1'] + self.assertEqual(h.formatter.custom_property, 'value') + + def test_config18_ok(self): + self.apply_config(self.config18) + handler = logging.getLogger('mymodule').handlers[0] + self.assertEqual(handler.flushLevel, logging.ERROR) + + def setup_via_listener(self, text, verify=None): + text = 
text.encode("utf-8") + # Ask for a randomly assigned port (by using port 0) + t = logging.config.listen(0, verify) + t.start() + t.ready.wait() + # Now get the port allocated + port = t.port + t.ready.clear() + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(2.0) + sock.connect(('localhost', port)) + + slen = struct.pack('>L', len(text)) + s = slen + text + sentsofar = 0 + left = len(s) + while left > 0: + sent = sock.send(s[sentsofar:]) + sentsofar += sent + left -= sent + sock.close() + finally: + t.ready.wait(2.0) + logging.config.stopListening() + threading_helper.join_thread(t) + + @support.requires_working_socket() + def test_listen_config_10_ok(self): + with support.captured_stdout() as output: + self.setup_via_listener(json.dumps(self.config10)) + self.check_handler('hand1', logging.StreamHandler) + logger = logging.getLogger("compiler.parser") + logger.warning(self.next_message()) + logger = logging.getLogger('compiler') + # Not output, because filtered + logger.warning(self.next_message()) + logger = logging.getLogger('compiler.lexer') + # Not output, because filtered + logger.warning(self.next_message()) + logger = logging.getLogger("compiler.parser.codegen") + # Output, as not filtered + logger.error(self.next_message()) + self.assert_log_lines([ + ('WARNING', '1'), + ('ERROR', '4'), + ], stream=output) + + @support.requires_working_socket() + def test_listen_config_1_ok(self): + with support.captured_stdout() as output: + self.setup_via_listener(textwrap.dedent(ConfigFileTest.config1)) + logger = logging.getLogger("compiler.parser") + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ], stream=output) + # Original logger output is empty. 
+ self.assert_log_lines([]) + + @support.requires_working_socket() + def test_listen_verify(self): + + def verify_fail(stuff): + return None + + def verify_reverse(stuff): + return stuff[::-1] + + logger = logging.getLogger("compiler.parser") + to_send = textwrap.dedent(ConfigFileTest.config1) + # First, specify a verification function that will fail. + # We expect to see no output, since our configuration + # never took effect. + with support.captured_stdout() as output: + self.setup_via_listener(to_send, verify_fail) + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([], stream=output) + # Original logger output has the stuff we logged. + self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ], pat=r"^[\w.]+ -> (\w+): (\d+)$") + + # Now, perform no verification. Our configuration + # should take effect. + + with support.captured_stdout() as output: + self.setup_via_listener(to_send) # no verify callable specified + logger = logging.getLogger("compiler.parser") + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([ + ('INFO', '3'), + ('ERROR', '4'), + ], stream=output) + # Original logger output still has the stuff we logged before. + self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ], pat=r"^[\w.]+ -> (\w+): (\d+)$") + + # Now, perform verification which transforms the bytes. + + with support.captured_stdout() as output: + self.setup_via_listener(to_send[::-1], verify_reverse) + logger = logging.getLogger("compiler.parser") + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([ + ('INFO', '5'), + ('ERROR', '6'), + ], stream=output) + # Original logger output still has the stuff we logged before. 
+ self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ], pat=r"^[\w.]+ -> (\w+): (\d+)$") + + def test_bad_format(self): + self.assertRaises(ValueError, self.apply_config, self.bad_format) + + def test_bad_format_with_dollar_style(self): + config = copy.deepcopy(self.bad_format) + config['formatters']['mySimpleFormatter']['format'] = "${asctime} (${name}) ${levelname}: ${message}" + + self.apply_config(config) + handler = logging.getLogger('mymodule').handlers[0] + self.assertIsInstance(handler.target, logging.Handler) + self.assertIsInstance(handler.formatter._style, + logging.StringTemplateStyle) + self.assertEqual(sorted(logging.getHandlerNames()), + ['bufferGlobal', 'fileGlobal']) + + def test_custom_formatter_class_with_validate(self): + self.apply_config(self.custom_formatter_class_validate) + handler = logging.getLogger("my_test_logger_custom_formatter").handlers[0] + self.assertIsInstance(handler.formatter, ExceptionFormatter) + + def test_custom_formatter_class_with_validate2(self): + self.apply_config(self.custom_formatter_class_validate2) + handler = logging.getLogger("my_test_logger_custom_formatter").handlers[0] + self.assertIsInstance(handler.formatter, ExceptionFormatter) + + def test_custom_formatter_class_with_validate2_with_wrong_fmt(self): + config = self.custom_formatter_class_validate.copy() + config['formatters']['form1']['style'] = "$" + + # Exception should not be raised as we have configured 'validate' to False + self.apply_config(config) + handler = logging.getLogger("my_test_logger_custom_formatter").handlers[0] + self.assertIsInstance(handler.formatter, ExceptionFormatter) + + def test_custom_formatter_class_with_validate3(self): + self.assertRaises(ValueError, self.apply_config, self.custom_formatter_class_validate3) + + def test_custom_formatter_function_with_validate(self): + self.assertRaises(ValueError, self.apply_config, self.custom_formatter_with_function) + + def test_custom_formatter_function_with_defaults(self): + 
self.assertRaises(ValueError, self.apply_config, self.custom_formatter_with_defaults) + + def test_baseconfig(self): + d = { + 'atuple': (1, 2, 3), + 'alist': ['a', 'b', 'c'], + 'adict': {'d': 'e', 'f': 3 }, + 'nest1': ('g', ('h', 'i'), 'j'), + 'nest2': ['k', ['l', 'm'], 'n'], + 'nest3': ['o', 'cfg://alist', 'p'], + } + bc = logging.config.BaseConfigurator(d) + self.assertEqual(bc.convert('cfg://atuple[1]'), 2) + self.assertEqual(bc.convert('cfg://alist[1]'), 'b') + self.assertEqual(bc.convert('cfg://nest1[1][0]'), 'h') + self.assertEqual(bc.convert('cfg://nest2[1][1]'), 'm') + self.assertEqual(bc.convert('cfg://adict.d'), 'e') + self.assertEqual(bc.convert('cfg://adict[f]'), 3) + v = bc.convert('cfg://nest3') + self.assertEqual(v.pop(1), ['a', 'b', 'c']) + self.assertRaises(KeyError, bc.convert, 'cfg://nosuch') + self.assertRaises(ValueError, bc.convert, 'cfg://!') + self.assertRaises(KeyError, bc.convert, 'cfg://adict[2]') + + def test_namedtuple(self): + # see bpo-39142 + from collections import namedtuple + + class MyHandler(logging.StreamHandler): + def __init__(self, resource, *args, **kwargs): + super().__init__(*args, **kwargs) + self.resource: namedtuple = resource + + def emit(self, record): + record.msg += f' {self.resource.type}' + return super().emit(record) + + Resource = namedtuple('Resource', ['type', 'labels']) + resource = Resource(type='my_type', labels=['a']) + + config = { + 'version': 1, + 'handlers': { + 'myhandler': { + '()': MyHandler, + 'resource': resource + } + }, + 'root': {'level': 'INFO', 'handlers': ['myhandler']}, + } + with support.captured_stderr() as stderr: + self.apply_config(config) + logging.info('some log') + self.assertEqual(stderr.getvalue(), 'some log my_type\n') + + def test_config_callable_filter_works(self): + def filter_(_): + return 1 + self.apply_config({ + "version": 1, "root": {"level": "DEBUG", "filters": [filter_]} + }) + assert logging.getLogger().filters[0] is filter_ + logging.getLogger().filters = [] + + def 
test_config_filter_works(self): + filter_ = logging.Filter("spam.eggs") + self.apply_config({ + "version": 1, "root": {"level": "DEBUG", "filters": [filter_]} + }) + assert logging.getLogger().filters[0] is filter_ + logging.getLogger().filters = [] + + def test_config_filter_method_works(self): + class FakeFilter: + def filter(self, _): + return 1 + filter_ = FakeFilter() + self.apply_config({ + "version": 1, "root": {"level": "DEBUG", "filters": [filter_]} + }) + assert logging.getLogger().filters[0] is filter_ + logging.getLogger().filters = [] + + def test_invalid_type_raises(self): + class NotAFilter: pass + for filter_ in [None, 1, NotAFilter()]: + self.assertRaises( + ValueError, + self.apply_config, + {"version": 1, "root": {"level": "DEBUG", "filters": [filter_]}} + ) + + def do_queuehandler_configuration(self, qspec, lspec): + cd = copy.deepcopy(self.config_queue_handler) + fn = make_temp_file('.log', 'test_logging-cqh-') + cd['handlers']['h1']['filename'] = fn + if qspec is not None: + cd['handlers']['ah']['queue'] = qspec + if lspec is not None: + cd['handlers']['ah']['listener'] = lspec + qh = None + try: + self.apply_config(cd) + qh = logging.getHandlerByName('ah') + self.assertEqual(sorted(logging.getHandlerNames()), ['ah', 'h1']) + self.assertIsNotNone(qh.listener) + qh.listener.start() + logging.debug('foo') + logging.info('bar') + logging.warning('baz') + + # Need to let the listener thread finish its work + while support.sleeping_retry(support.LONG_TIMEOUT, + "queue not empty"): + if qh.listener.queue.empty(): + break + + # wait until the handler completed its last task + qh.listener.queue.join() + + with open(fn, encoding='utf-8') as f: + data = f.read().splitlines() + self.assertEqual(data, ['foo', 'bar', 'baz']) + finally: + if qh: + qh.listener.stop() + h = logging.getHandlerByName('h1') + if h: + self.addCleanup(closeFileHandler, h, fn) + else: + self.addCleanup(os.remove, fn) + + @threading_helper.requires_working_threading() + 
@support.requires_subprocess() + def test_config_queue_handler(self): + qs = [CustomQueue(), CustomQueueProtocol()] + dqs = [{'()': f'{__name__}.{cls}', 'maxsize': 10} + for cls in ['CustomQueue', 'CustomQueueProtocol']] + dl = { + '()': __name__ + '.listenerMaker', + 'arg1': None, + 'arg2': None, + 'respect_handler_level': True + } + qvalues = (None, __name__ + '.queueMaker', __name__ + '.CustomQueue', *dqs, *qs) + lvalues = (None, __name__ + '.CustomListener', dl, CustomListener) + for qspec, lspec in itertools.product(qvalues, lvalues): + self.do_queuehandler_configuration(qspec, lspec) + + # Some failure cases + qvalues = (None, 4, int, '', 'foo') + lvalues = (None, 4, int, '', 'bar') + for qspec, lspec in itertools.product(qvalues, lvalues): + if lspec is None and qspec is None: + continue + with self.assertRaises(ValueError) as ctx: + self.do_queuehandler_configuration(qspec, lspec) + msg = str(ctx.exception) + self.assertEqual(msg, "Unable to configure handler 'ah'") + + def _apply_simple_queue_listener_configuration(self, qspec): + self.apply_config({ + "version": 1, + "handlers": { + "queue_listener": { + "class": "logging.handlers.QueueHandler", + "queue": qspec, + }, + }, + }) + + @threading_helper.requires_working_threading() + @support.requires_subprocess() + @patch("multiprocessing.Manager") + def test_config_queue_handler_does_not_create_multiprocessing_manager(self, manager): + # gh-120868, gh-121723, gh-124653 + + for qspec in [ + {"()": "queue.Queue", "maxsize": -1}, + queue.Queue(), + # queue.SimpleQueue does not inherit from queue.Queue + queue.SimpleQueue(), + # CustomQueueFakeProtocol passes the checks but will not be usable + # since the signatures are incompatible. Checking the Queue API + # without testing the type of the actual queue is a trade-off + # between usability and the work we need to do in order to safely + # check that the queue object correctly implements the API. 
+ CustomQueueFakeProtocol(), + MinimalQueueProtocol(), + ]: + with self.subTest(qspec=qspec): + self._apply_simple_queue_listener_configuration(qspec) + manager.assert_not_called() + + @patch("multiprocessing.Manager") + def test_config_queue_handler_invalid_config_does_not_create_multiprocessing_manager(self, manager): + # gh-120868, gh-121723 + + for qspec in [object(), CustomQueueWrongProtocol()]: + with self.subTest(qspec=qspec), self.assertRaises(ValueError): + self._apply_simple_queue_listener_configuration(qspec) + manager.assert_not_called() + + @skip_if_tsan_fork + @support.requires_subprocess() + @unittest.skipUnless(support.Py_DEBUG, "requires a debug build for testing" + " assertions in multiprocessing") + def test_config_reject_simple_queue_handler_multiprocessing_context(self): + # multiprocessing.SimpleQueue does not implement 'put_nowait' + # and thus cannot be used as a queue-like object (gh-124653) + + import multiprocessing + + if support.MS_WINDOWS: + start_methods = ['spawn'] + else: + start_methods = ['spawn', 'fork', 'forkserver'] + + for start_method in start_methods: + with self.subTest(start_method=start_method): + ctx = multiprocessing.get_context(start_method) + qspec = ctx.SimpleQueue() + with self.assertRaises(ValueError): + self._apply_simple_queue_listener_configuration(qspec) + + @skip_if_tsan_fork + @support.requires_subprocess() + @unittest.skipUnless(support.Py_DEBUG, "requires a debug build for testing" + "assertions in multiprocessing") + def test_config_queue_handler_multiprocessing_context(self): + # regression test for gh-121723 + if support.MS_WINDOWS: + start_methods = ['spawn'] + else: + start_methods = ['spawn', 'fork', 'forkserver'] + for start_method in start_methods: + with self.subTest(start_method=start_method): + ctx = multiprocessing.get_context(start_method) + with ctx.Manager() as manager: + q = manager.Queue() + records = [] + # use 1 process and 1 task per child to put 1 record + with ctx.Pool(1, 
initializer=self._mpinit_issue121723, + initargs=(q, "text"), maxtasksperchild=1): + records.append(q.get(timeout=60)) + self.assertTrue(q.empty()) + self.assertEqual(len(records), 1) + + @staticmethod + def _mpinit_issue121723(qspec, message_to_log): + # static method for pickling support + logging.config.dictConfig({ + 'version': 1, + 'disable_existing_loggers': True, + 'handlers': { + 'log_to_parent': { + 'class': 'logging.handlers.QueueHandler', + 'queue': qspec + } + }, + 'root': {'handlers': ['log_to_parent'], 'level': 'DEBUG'} + }) + # log a message (this creates a record put in the queue) + logging.getLogger().info(message_to_log) + + # TODO: RustPython + @unittest.expectedFailure + @support.requires_subprocess() + def test_multiprocessing_queues(self): + # See gh-119819 + + cd = copy.deepcopy(self.config_queue_handler) + from multiprocessing import Queue as MQ, Manager as MM + q1 = MQ() # this can't be pickled + q2 = MM().Queue() # a proxy queue for use when pickling is needed + q3 = MM().JoinableQueue() # a joinable proxy queue + for qspec in (q1, q2, q3): + fn = make_temp_file('.log', 'test_logging-cmpqh-') + cd['handlers']['h1']['filename'] = fn + cd['handlers']['ah']['queue'] = qspec + qh = None + try: + self.apply_config(cd) + qh = logging.getHandlerByName('ah') + self.assertEqual(sorted(logging.getHandlerNames()), ['ah', 'h1']) + self.assertIsNotNone(qh.listener) + self.assertIs(qh.queue, qspec) + self.assertIs(qh.listener.queue, qspec) + finally: + h = logging.getHandlerByName('h1') + if h: + self.addCleanup(closeFileHandler, h, fn) + else: + self.addCleanup(os.remove, fn) + + def test_90195(self): + # See gh-90195 + config = { + 'version': 1, + 'disable_existing_loggers': False, + 'handlers': { + 'console': { + 'level': 'DEBUG', + 'class': 'logging.StreamHandler', + }, + }, + 'loggers': { + 'a': { + 'level': 'DEBUG', + 'handlers': ['console'] + } + } + } + logger = logging.getLogger('a') + self.assertFalse(logger.disabled) + 
self.apply_config(config) + self.assertFalse(logger.disabled) + # Should disable all loggers ... + self.apply_config({'version': 1}) + self.assertTrue(logger.disabled) + del config['disable_existing_loggers'] + self.apply_config(config) + # Logger should be enabled, since explicitly mentioned + self.assertFalse(logger.disabled) + + # TODO: RustPython + @unittest.expectedFailure + def test_111615(self): + # See gh-111615 + import_helper.import_module('_multiprocessing') # see gh-113692 + mp = import_helper.import_module('multiprocessing') + + config = { + 'version': 1, + 'handlers': { + 'sink': { + 'class': 'logging.handlers.QueueHandler', + 'queue': mp.get_context('spawn').Queue(), + }, + }, + 'root': { + 'handlers': ['sink'], + 'level': 'DEBUG', + }, + } + logging.config.dictConfig(config) + + # gh-118868: check if kwargs are passed to logging QueueHandler + def test_kwargs_passing(self): + class CustomQueueHandler(logging.handlers.QueueHandler): + def __init__(self, *args, **kwargs): + super().__init__(queue.Queue()) + self.custom_kwargs = kwargs + + custom_kwargs = {'foo': 'bar'} + + config = { + 'version': 1, + 'handlers': { + 'custom': { + 'class': CustomQueueHandler, + **custom_kwargs + }, + }, + 'root': { + 'level': 'DEBUG', + 'handlers': ['custom'] + } + } + + logging.config.dictConfig(config) + + handler = logging.getHandlerByName('custom') + self.assertEqual(handler.custom_kwargs, custom_kwargs) + + +class ManagerTest(BaseTest): + def test_manager_loggerclass(self): + logged = [] + + class MyLogger(logging.Logger): + def _log(self, level, msg, args, exc_info=None, extra=None): + logged.append(msg) + + man = logging.Manager(None) + self.assertRaises(TypeError, man.setLoggerClass, int) + man.setLoggerClass(MyLogger) + logger = man.getLogger('test') + logger.warning('should appear in logged') + logging.warning('should not appear in logged') + + self.assertEqual(logged, ['should appear in logged']) + + def test_set_log_record_factory(self): + man = 
logging.Manager(None) + expected = object() + man.setLogRecordFactory(expected) + self.assertEqual(man.logRecordFactory, expected) + +class ChildLoggerTest(BaseTest): + def test_child_loggers(self): + r = logging.getLogger() + l1 = logging.getLogger('abc') + l2 = logging.getLogger('def.ghi') + c1 = r.getChild('xyz') + c2 = r.getChild('uvw.xyz') + self.assertIs(c1, logging.getLogger('xyz')) + self.assertIs(c2, logging.getLogger('uvw.xyz')) + c1 = l1.getChild('def') + c2 = c1.getChild('ghi') + c3 = l1.getChild('def.ghi') + self.assertIs(c1, logging.getLogger('abc.def')) + self.assertIs(c2, logging.getLogger('abc.def.ghi')) + self.assertIs(c2, c3) + + def test_get_children(self): + r = logging.getLogger() + l1 = logging.getLogger('foo') + l2 = logging.getLogger('foo.bar') + l3 = logging.getLogger('foo.bar.baz.bozz') + l4 = logging.getLogger('bar') + kids = r.getChildren() + expected = {l1, l4} + self.assertEqual(expected, kids & expected) # might be other kids for root + self.assertNotIn(l2, expected) + kids = l1.getChildren() + self.assertEqual({l2}, kids) + kids = l2.getChildren() + self.assertEqual(set(), kids) + +class DerivedLogRecord(logging.LogRecord): + pass + +class LogRecordFactoryTest(BaseTest): + + def setUp(self): + class CheckingFilter(logging.Filter): + def __init__(self, cls): + self.cls = cls + + def filter(self, record): + t = type(record) + if t is not self.cls: + msg = 'Unexpected LogRecord type %s, expected %s' % (t, + self.cls) + raise TypeError(msg) + return True + + BaseTest.setUp(self) + self.filter = CheckingFilter(DerivedLogRecord) + self.root_logger.addFilter(self.filter) + self.orig_factory = logging.getLogRecordFactory() + + def tearDown(self): + self.root_logger.removeFilter(self.filter) + BaseTest.tearDown(self) + logging.setLogRecordFactory(self.orig_factory) + + def test_logrecord_class(self): + self.assertRaises(TypeError, self.root_logger.warning, + self.next_message()) + logging.setLogRecordFactory(DerivedLogRecord) + 
self.root_logger.error(self.next_message()) + self.assert_log_lines([ + ('root', 'ERROR', '2'), + ]) + + +@threading_helper.requires_working_threading() +class QueueHandlerTest(BaseTest): + # Do not bother with a logger name group. + expected_log_pat = r"^[\w.]+ -> (\w+): (\d+)$" + + def setUp(self): + BaseTest.setUp(self) + self.queue = queue.Queue(-1) + self.que_hdlr = logging.handlers.QueueHandler(self.queue) + self.name = 'que' + self.que_logger = logging.getLogger('que') + self.que_logger.propagate = False + self.que_logger.setLevel(logging.WARNING) + self.que_logger.addHandler(self.que_hdlr) + + def tearDown(self): + self.que_hdlr.close() + BaseTest.tearDown(self) + + def test_queue_handler(self): + self.que_logger.debug(self.next_message()) + self.assertRaises(queue.Empty, self.queue.get_nowait) + self.que_logger.info(self.next_message()) + self.assertRaises(queue.Empty, self.queue.get_nowait) + msg = self.next_message() + self.que_logger.warning(msg) + data = self.queue.get_nowait() + self.assertTrue(isinstance(data, logging.LogRecord)) + self.assertEqual(data.name, self.que_logger.name) + self.assertEqual((data.msg, data.args), (msg, None)) + + def test_formatting(self): + msg = self.next_message() + levelname = logging.getLevelName(logging.WARNING) + log_format_str = '{name} -> {levelname}: {message}' + formatted_msg = log_format_str.format(name=self.name, + levelname=levelname, message=msg) + formatter = logging.Formatter(self.log_format) + self.que_hdlr.setFormatter(formatter) + self.que_logger.warning(msg) + log_record = self.queue.get_nowait() + self.assertEqual(formatted_msg, log_record.msg) + self.assertEqual(formatted_msg, log_record.message) + + @unittest.skipUnless(hasattr(logging.handlers, 'QueueListener'), + 'logging.handlers.QueueListener required for this test') + def test_queue_listener(self): + handler = TestHandler(support.Matcher()) + listener = logging.handlers.QueueListener(self.queue, handler) + listener.start() + try: + 
self.que_logger.warning(self.next_message()) + self.que_logger.error(self.next_message()) + self.que_logger.critical(self.next_message()) + finally: + listener.stop() + self.assertTrue(handler.matches(levelno=logging.WARNING, message='1')) + self.assertTrue(handler.matches(levelno=logging.ERROR, message='2')) + self.assertTrue(handler.matches(levelno=logging.CRITICAL, message='3')) + handler.close() + + # Now test with respect_handler_level set + + handler = TestHandler(support.Matcher()) + handler.setLevel(logging.CRITICAL) + listener = logging.handlers.QueueListener(self.queue, handler, + respect_handler_level=True) + listener.start() + try: + self.que_logger.warning(self.next_message()) + self.que_logger.error(self.next_message()) + self.que_logger.critical(self.next_message()) + finally: + listener.stop() + self.assertFalse(handler.matches(levelno=logging.WARNING, message='4')) + self.assertFalse(handler.matches(levelno=logging.ERROR, message='5')) + self.assertTrue(handler.matches(levelno=logging.CRITICAL, message='6')) + handler.close() + + @unittest.skipUnless(hasattr(logging.handlers, 'QueueListener'), + 'logging.handlers.QueueListener required for this test') + def test_queue_listener_with_StreamHandler(self): + # Test that traceback and stack-info only appends once (bpo-34334, bpo-46755). 
+ listener = logging.handlers.QueueListener(self.queue, self.root_hdlr) + listener.start() + try: + 1 / 0 + except ZeroDivisionError as e: + exc = e + self.que_logger.exception(self.next_message(), exc_info=exc) + self.que_logger.error(self.next_message(), stack_info=True) + listener.stop() + self.assertEqual(self.stream.getvalue().strip().count('Traceback'), 1) + self.assertEqual(self.stream.getvalue().strip().count('Stack'), 1) + + @unittest.skipUnless(hasattr(logging.handlers, 'QueueListener'), + 'logging.handlers.QueueListener required for this test') + def test_queue_listener_with_multiple_handlers(self): + # Test that queue handler format doesn't affect other handler formats (bpo-35726). + self.que_hdlr.setFormatter(self.root_formatter) + self.que_logger.addHandler(self.root_hdlr) + + listener = logging.handlers.QueueListener(self.queue, self.que_hdlr) + listener.start() + self.que_logger.error("error") + listener.stop() + self.assertEqual(self.stream.getvalue().strip(), "que -> ERROR: error") + +if hasattr(logging.handlers, 'QueueListener'): + import multiprocessing + from unittest.mock import patch + + @threading_helper.requires_working_threading() + class QueueListenerTest(BaseTest): + """ + Tests based on patch submitted for issue #27930. Ensure that + QueueListener handles all log messages. + """ + + repeat = 20 + + @staticmethod + def setup_and_log(log_queue, ident): + """ + Creates a logger with a QueueHandler that logs to a queue read by a + QueueListener. Starts the listener, logs five messages, and stops + the listener. 
+ """ + logger = logging.getLogger('test_logger_with_id_%s' % ident) + logger.setLevel(logging.DEBUG) + handler = logging.handlers.QueueHandler(log_queue) + logger.addHandler(handler) + listener = logging.handlers.QueueListener(log_queue) + listener.start() + + logger.info('one') + logger.info('two') + logger.info('three') + logger.info('four') + logger.info('five') + + listener.stop() + logger.removeHandler(handler) + handler.close() + + @patch.object(logging.handlers.QueueListener, 'handle') + def test_handle_called_with_queue_queue(self, mock_handle): + for i in range(self.repeat): + log_queue = queue.Queue() + self.setup_and_log(log_queue, '%s_%s' % (self.id(), i)) + self.assertEqual(mock_handle.call_count, 5 * self.repeat, + 'correct number of handled log messages') + + @patch.object(logging.handlers.QueueListener, 'handle') + def test_handle_called_with_mp_queue(self, mock_handle): + # bpo-28668: The multiprocessing (mp) module is not functional + # when the mp.synchronize module cannot be imported. + support.skip_if_broken_multiprocessing_synchronize() + for i in range(self.repeat): + log_queue = multiprocessing.Queue() + self.setup_and_log(log_queue, '%s_%s' % (self.id(), i)) + log_queue.close() + log_queue.join_thread() + self.assertEqual(mock_handle.call_count, 5 * self.repeat, + 'correct number of handled log messages') + + @staticmethod + def get_all_from_queue(log_queue): + try: + while True: + yield log_queue.get_nowait() + except queue.Empty: + return [] + + def test_no_messages_in_queue_after_stop(self): + """ + Five messages are logged then the QueueListener is stopped. This + test then gets everything off the queue. Failure of this test + indicates that messages were not registered on the queue until + _after_ the QueueListener stopped. + """ + # bpo-28668: The multiprocessing (mp) module is not functional + # when the mp.synchronize module cannot be imported. 
+ support.skip_if_broken_multiprocessing_synchronize() + for i in range(self.repeat): + queue = multiprocessing.Queue() + self.setup_and_log(queue, '%s_%s' %(self.id(), i)) + # time.sleep(1) + items = list(self.get_all_from_queue(queue)) + queue.close() + queue.join_thread() + + expected = [[], [logging.handlers.QueueListener._sentinel]] + self.assertIn(items, expected, + 'Found unexpected messages in queue: %s' % ( + [m.msg if isinstance(m, logging.LogRecord) + else m for m in items])) + + def test_calls_task_done_after_stop(self): + # Issue 36813: Make sure queue.join does not deadlock. + log_queue = queue.Queue() + listener = logging.handlers.QueueListener(log_queue) + listener.start() + listener.stop() + with self.assertRaises(ValueError): + # Make sure all tasks are done and .join won't block. + log_queue.task_done() + + +ZERO = datetime.timedelta(0) + +class UTC(datetime.tzinfo): + def utcoffset(self, dt): + return ZERO + + dst = utcoffset + + def tzname(self, dt): + return 'UTC' + +utc = UTC() + +class AssertErrorMessage: + + def assert_error_message(self, exception, message, *args, **kwargs): + try: + self.assertRaises((), *args, **kwargs) + except exception as e: + self.assertEqual(message, str(e)) + +class FormatterTest(unittest.TestCase, AssertErrorMessage): + def setUp(self): + self.common = { + 'name': 'formatter.test', + 'level': logging.DEBUG, + 'pathname': os.path.join('path', 'to', 'dummy.ext'), + 'lineno': 42, + 'exc_info': None, + 'func': None, + 'msg': 'Message with %d %s', + 'args': (2, 'placeholders'), + } + self.variants = { + 'custom': { + 'custom': 1234 + } + } + + def get_record(self, name=None): + result = dict(self.common) + if name is not None: + result.update(self.variants[name]) + return logging.makeLogRecord(result) + + def test_percent(self): + # Test %-formatting + r = self.get_record() + f = logging.Formatter('${%(message)s}') + self.assertEqual(f.format(r), '${Message with 2 placeholders}') + f = logging.Formatter('%(random)s') + 
self.assertRaises(ValueError, f.format, r) + self.assertFalse(f.usesTime()) + f = logging.Formatter('%(asctime)s') + self.assertTrue(f.usesTime()) + f = logging.Formatter('%(asctime)-15s') + self.assertTrue(f.usesTime()) + f = logging.Formatter('%(asctime)#15s') + self.assertTrue(f.usesTime()) + + def test_braces(self): + # Test {}-formatting + r = self.get_record() + f = logging.Formatter('$%{message}%$', style='{') + self.assertEqual(f.format(r), '$%Message with 2 placeholders%$') + f = logging.Formatter('{random}', style='{') + self.assertRaises(ValueError, f.format, r) + f = logging.Formatter("{message}", style='{') + self.assertFalse(f.usesTime()) + f = logging.Formatter('{asctime}', style='{') + self.assertTrue(f.usesTime()) + f = logging.Formatter('{asctime!s:15}', style='{') + self.assertTrue(f.usesTime()) + f = logging.Formatter('{asctime:15}', style='{') + self.assertTrue(f.usesTime()) + + def test_dollars(self): + # Test $-formatting + r = self.get_record() + f = logging.Formatter('${message}', style='$') + self.assertEqual(f.format(r), 'Message with 2 placeholders') + f = logging.Formatter('$message', style='$') + self.assertEqual(f.format(r), 'Message with 2 placeholders') + f = logging.Formatter('$$%${message}%$$', style='$') + self.assertEqual(f.format(r), '$%Message with 2 placeholders%$') + f = logging.Formatter('${random}', style='$') + self.assertRaises(ValueError, f.format, r) + self.assertFalse(f.usesTime()) + f = logging.Formatter('${asctime}', style='$') + self.assertTrue(f.usesTime()) + f = logging.Formatter('$asctime', style='$') + self.assertTrue(f.usesTime()) + f = logging.Formatter('${message}', style='$') + self.assertFalse(f.usesTime()) + f = logging.Formatter('${asctime}--', style='$') + self.assertTrue(f.usesTime()) + + # TODO: RustPython + @unittest.expectedFailure + def test_format_validate(self): + # Check correct formatting + # Percentage style + f = logging.Formatter("%(levelname)-15s - %(message) 5s - %(process)03d - %(module) 
- %(asctime)*.3s") + self.assertEqual(f._fmt, "%(levelname)-15s - %(message) 5s - %(process)03d - %(module) - %(asctime)*.3s") + f = logging.Formatter("%(asctime)*s - %(asctime)*.3s - %(process)-34.33o") + self.assertEqual(f._fmt, "%(asctime)*s - %(asctime)*.3s - %(process)-34.33o") + f = logging.Formatter("%(process)#+027.23X") + self.assertEqual(f._fmt, "%(process)#+027.23X") + f = logging.Formatter("%(foo)#.*g") + self.assertEqual(f._fmt, "%(foo)#.*g") + + # StrFormat Style + f = logging.Formatter("$%{message}%$ - {asctime!a:15} - {customfield['key']}", style="{") + self.assertEqual(f._fmt, "$%{message}%$ - {asctime!a:15} - {customfield['key']}") + f = logging.Formatter("{process:.2f} - {custom.f:.4f}", style="{") + self.assertEqual(f._fmt, "{process:.2f} - {custom.f:.4f}") + f = logging.Formatter("{customfield!s:#<30}", style="{") + self.assertEqual(f._fmt, "{customfield!s:#<30}") + f = logging.Formatter("{message!r}", style="{") + self.assertEqual(f._fmt, "{message!r}") + f = logging.Formatter("{message!s}", style="{") + self.assertEqual(f._fmt, "{message!s}") + f = logging.Formatter("{message!a}", style="{") + self.assertEqual(f._fmt, "{message!a}") + f = logging.Formatter("{process!r:4.2}", style="{") + self.assertEqual(f._fmt, "{process!r:4.2}") + f = logging.Formatter("{process!s:<#30,.12f}- {custom:=+#30,.1d} - {module:^30}", style="{") + self.assertEqual(f._fmt, "{process!s:<#30,.12f}- {custom:=+#30,.1d} - {module:^30}") + f = logging.Formatter("{process!s:{w},.{p}}", style="{") + self.assertEqual(f._fmt, "{process!s:{w},.{p}}") + f = logging.Formatter("{foo:12.{p}}", style="{") + self.assertEqual(f._fmt, "{foo:12.{p}}") + f = logging.Formatter("{foo:{w}.6}", style="{") + self.assertEqual(f._fmt, "{foo:{w}.6}") + f = logging.Formatter("{foo[0].bar[1].baz}", style="{") + self.assertEqual(f._fmt, "{foo[0].bar[1].baz}") + f = logging.Formatter("{foo[k1].bar[k2].baz}", style="{") + self.assertEqual(f._fmt, "{foo[k1].bar[k2].baz}") + f = 
logging.Formatter("{12[k1].bar[k2].baz}", style="{") + self.assertEqual(f._fmt, "{12[k1].bar[k2].baz}") + + # Dollar style + f = logging.Formatter("${asctime} - $message", style="$") + self.assertEqual(f._fmt, "${asctime} - $message") + f = logging.Formatter("$bar $$", style="$") + self.assertEqual(f._fmt, "$bar $$") + f = logging.Formatter("$bar $$$$", style="$") + self.assertEqual(f._fmt, "$bar $$$$") # this would print two $($$) + + # Testing when ValueError being raised from incorrect format + # Percentage Style + self.assertRaises(ValueError, logging.Formatter, "%(asctime)Z") + self.assertRaises(ValueError, logging.Formatter, "%(asctime)b") + self.assertRaises(ValueError, logging.Formatter, "%(asctime)*") + self.assertRaises(ValueError, logging.Formatter, "%(asctime)*3s") + self.assertRaises(ValueError, logging.Formatter, "%(asctime)_") + self.assertRaises(ValueError, logging.Formatter, '{asctime}') + self.assertRaises(ValueError, logging.Formatter, '${message}') + self.assertRaises(ValueError, logging.Formatter, '%(foo)#12.3*f') # with both * and decimal number as precision + self.assertRaises(ValueError, logging.Formatter, '%(foo)0*.8*f') + + # StrFormat Style + # Testing failure for '-' in field name + self.assert_error_message( + ValueError, + "invalid format: invalid field name/expression: 'name-thing'", + logging.Formatter, "{name-thing}", style="{" + ) + # Testing failure for style mismatch + self.assert_error_message( + ValueError, + "invalid format: no fields", + logging.Formatter, '%(asctime)s', style='{' + ) + # Testing failure for invalid conversion + self.assert_error_message( + ValueError, + "invalid conversion: 'Z'" + ) + self.assertRaises(ValueError, logging.Formatter, '{asctime!s:#30,15f}', style='{') + self.assert_error_message( + ValueError, + "invalid format: expected ':' after conversion specifier", + logging.Formatter, '{asctime!aa:15}', style='{' + ) + # Testing failure for invalid spec + self.assert_error_message( + ValueError, + 
"invalid format: bad specifier: '.2ff'", + logging.Formatter, '{process:.2ff}', style='{' + ) + self.assertRaises(ValueError, logging.Formatter, '{process:.2Z}', style='{') + self.assertRaises(ValueError, logging.Formatter, '{process!s:<##30,12g}', style='{') + self.assertRaises(ValueError, logging.Formatter, '{process!s:<#30#,12g}', style='{') + self.assertRaises(ValueError, logging.Formatter, '{process!s:{{w}},{{p}}}', style='{') + # Testing failure for mismatch braces + self.assert_error_message( + ValueError, + "invalid format: expected '}' before end of string", + logging.Formatter, '{process', style='{' + ) + self.assert_error_message( + ValueError, + "invalid format: Single '}' encountered in format string", + logging.Formatter, 'process}', style='{' + ) + self.assertRaises(ValueError, logging.Formatter, '{{foo!r:4.2}', style='{') + self.assertRaises(ValueError, logging.Formatter, '{{foo!r:4.2}}', style='{') + self.assertRaises(ValueError, logging.Formatter, '{foo/bar}', style='{') + self.assertRaises(ValueError, logging.Formatter, '{foo:{{w}}.{{p}}}}', style='{') + self.assertRaises(ValueError, logging.Formatter, '{foo!X:{{w}}.{{p}}}', style='{') + self.assertRaises(ValueError, logging.Formatter, '{foo!a:random}', style='{') + self.assertRaises(ValueError, logging.Formatter, '{foo!a:ran{dom}', style='{') + self.assertRaises(ValueError, logging.Formatter, '{foo!a:ran{d}om}', style='{') + self.assertRaises(ValueError, logging.Formatter, '{foo.!a:d}', style='{') + + # Dollar style + # Testing failure for mismatch bare $ + self.assert_error_message( + ValueError, + "invalid format: bare \'$\' not allowed", + logging.Formatter, '$bar $$$', style='$' + ) + self.assert_error_message( + ValueError, + "invalid format: bare \'$\' not allowed", + logging.Formatter, 'bar $', style='$' + ) + self.assert_error_message( + ValueError, + "invalid format: bare \'$\' not allowed", + logging.Formatter, 'foo $.', style='$' + ) + # Testing failure for mismatch style + 
self.assert_error_message( + ValueError, + "invalid format: no fields", + logging.Formatter, '{asctime}', style='$' + ) + self.assertRaises(ValueError, logging.Formatter, '%(asctime)s', style='$') + + # Testing failure for incorrect fields + self.assert_error_message( + ValueError, + "invalid format: no fields", + logging.Formatter, 'foo', style='$' + ) + self.assertRaises(ValueError, logging.Formatter, '${asctime', style='$') + + def test_defaults_parameter(self): + fmts = ['%(custom)s %(message)s', '{custom} {message}', '$custom $message'] + styles = ['%', '{', '$'] + for fmt, style in zip(fmts, styles): + f = logging.Formatter(fmt, style=style, defaults={'custom': 'Default'}) + r = self.get_record() + self.assertEqual(f.format(r), 'Default Message with 2 placeholders') + r = self.get_record("custom") + self.assertEqual(f.format(r), '1234 Message with 2 placeholders') + + # Without default + f = logging.Formatter(fmt, style=style) + r = self.get_record() + self.assertRaises(ValueError, f.format, r) + + # Non-existing default is ignored + f = logging.Formatter(fmt, style=style, defaults={'Non-existing': 'Default'}) + r = self.get_record("custom") + self.assertEqual(f.format(r), '1234 Message with 2 placeholders') + + def test_invalid_style(self): + self.assertRaises(ValueError, logging.Formatter, None, None, 'x') + + # TODO: RustPython + @unittest.expectedFailure + def test_time(self): + r = self.get_record() + dt = datetime.datetime(1993, 4, 21, 8, 3, 0, 0, utc) + # We use None to indicate we want the local timezone + # We're essentially converting a UTC time to local time + r.created = time.mktime(dt.astimezone(None).timetuple()) + r.msecs = 123 + f = logging.Formatter('%(asctime)s %(message)s') + f.converter = time.gmtime + self.assertEqual(f.formatTime(r), '1993-04-21 08:03:00,123') + self.assertEqual(f.formatTime(r, '%Y:%d'), '1993:21') + f.format(r) + self.assertEqual(r.asctime, '1993-04-21 08:03:00,123') + + # TODO: RustPython + @unittest.expectedFailure + 
def test_default_msec_format_none(self): + class NoMsecFormatter(logging.Formatter): + default_msec_format = None + default_time_format = '%d/%m/%Y %H:%M:%S' + + r = self.get_record() + dt = datetime.datetime(1993, 4, 21, 8, 3, 0, 123, utc) + r.created = time.mktime(dt.astimezone(None).timetuple()) + f = NoMsecFormatter() + f.converter = time.gmtime + self.assertEqual(f.formatTime(r), '21/04/1993 08:03:00') + + def test_issue_89047(self): + f = logging.Formatter(fmt='{asctime}.{msecs:03.0f} {message}', style='{', datefmt="%Y-%m-%d %H:%M:%S") + for i in range(2500): + time.sleep(0.0004) + r = logging.makeLogRecord({'msg': 'Message %d' % (i + 1)}) + s = f.format(r) + self.assertNotIn('.1000', s) + + +class TestBufferingFormatter(logging.BufferingFormatter): + def formatHeader(self, records): + return '[(%d)' % len(records) + + def formatFooter(self, records): + return '(%d)]' % len(records) + +class BufferingFormatterTest(unittest.TestCase): + def setUp(self): + self.records = [ + logging.makeLogRecord({'msg': 'one'}), + logging.makeLogRecord({'msg': 'two'}), + ] + + def test_default(self): + f = logging.BufferingFormatter() + self.assertEqual('', f.format([])) + self.assertEqual('onetwo', f.format(self.records)) + + def test_custom(self): + f = TestBufferingFormatter() + self.assertEqual('[(2)onetwo(2)]', f.format(self.records)) + lf = logging.Formatter('<%(message)s>') + f = TestBufferingFormatter(lf) + self.assertEqual('[(2)(2)]', f.format(self.records)) + +class ExceptionTest(BaseTest): + def test_formatting(self): + r = self.root_logger + h = RecordingHandler() + r.addHandler(h) + try: + raise RuntimeError('deliberate mistake') + except: + logging.exception('failed', stack_info=True) + r.removeHandler(h) + h.close() + r = h.records[0] + self.assertTrue(r.exc_text.startswith('Traceback (most recent ' + 'call last):\n')) + self.assertTrue(r.exc_text.endswith('\nRuntimeError: ' + 'deliberate mistake')) + self.assertTrue(r.stack_info.startswith('Stack (most recent ' 
+ 'call last):\n')) + self.assertTrue(r.stack_info.endswith('logging.exception(\'failed\', ' + 'stack_info=True)')) + + +class LastResortTest(BaseTest): + def test_last_resort(self): + # Test the last resort handler + root = self.root_logger + root.removeHandler(self.root_hdlr) + old_lastresort = logging.lastResort + old_raise_exceptions = logging.raiseExceptions + + try: + with support.captured_stderr() as stderr: + root.debug('This should not appear') + self.assertEqual(stderr.getvalue(), '') + root.warning('Final chance!') + self.assertEqual(stderr.getvalue(), 'Final chance!\n') + + # No handlers and no last resort, so 'No handlers' message + logging.lastResort = None + with support.captured_stderr() as stderr: + root.warning('Final chance!') + msg = 'No handlers could be found for logger "root"\n' + self.assertEqual(stderr.getvalue(), msg) + + # 'No handlers' message only printed once + with support.captured_stderr() as stderr: + root.warning('Final chance!') + self.assertEqual(stderr.getvalue(), '') + + # If raiseExceptions is False, no message is printed + root.manager.emittedNoHandlerWarning = False + logging.raiseExceptions = False + with support.captured_stderr() as stderr: + root.warning('Final chance!') + self.assertEqual(stderr.getvalue(), '') + finally: + root.addHandler(self.root_hdlr) + logging.lastResort = old_lastresort + logging.raiseExceptions = old_raise_exceptions + + +class FakeHandler: + + def __init__(self, identifier, called): + for method in ('acquire', 'flush', 'close', 'release'): + setattr(self, method, self.record_call(identifier, method, called)) + + def record_call(self, identifier, method_name, called): + def inner(): + called.append('{} - {}'.format(identifier, method_name)) + return inner + + +class RecordingHandler(logging.NullHandler): + + def __init__(self, *args, **kwargs): + super(RecordingHandler, self).__init__(*args, **kwargs) + self.records = [] + + def handle(self, record): + """Keep track of all the emitted records.""" 
+ self.records.append(record) + + +class ShutdownTest(BaseTest): + + """Test suite for the shutdown method.""" + + def setUp(self): + super(ShutdownTest, self).setUp() + self.called = [] + + raise_exceptions = logging.raiseExceptions + self.addCleanup(setattr, logging, 'raiseExceptions', raise_exceptions) + + def raise_error(self, error): + def inner(): + raise error() + return inner + + def test_no_failure(self): + # create some fake handlers + handler0 = FakeHandler(0, self.called) + handler1 = FakeHandler(1, self.called) + handler2 = FakeHandler(2, self.called) + + # create live weakref to those handlers + handlers = map(logging.weakref.ref, [handler0, handler1, handler2]) + + logging.shutdown(handlerList=list(handlers)) + + expected = ['2 - acquire', '2 - flush', '2 - close', '2 - release', + '1 - acquire', '1 - flush', '1 - close', '1 - release', + '0 - acquire', '0 - flush', '0 - close', '0 - release'] + self.assertEqual(expected, self.called) + + def _test_with_failure_in_method(self, method, error): + handler = FakeHandler(0, self.called) + setattr(handler, method, self.raise_error(error)) + handlers = [logging.weakref.ref(handler)] + + logging.shutdown(handlerList=list(handlers)) + + self.assertEqual('0 - release', self.called[-1]) + + def test_with_ioerror_in_acquire(self): + self._test_with_failure_in_method('acquire', OSError) + + def test_with_ioerror_in_flush(self): + self._test_with_failure_in_method('flush', OSError) + + def test_with_ioerror_in_close(self): + self._test_with_failure_in_method('close', OSError) + + def test_with_valueerror_in_acquire(self): + self._test_with_failure_in_method('acquire', ValueError) + + def test_with_valueerror_in_flush(self): + self._test_with_failure_in_method('flush', ValueError) + + def test_with_valueerror_in_close(self): + self._test_with_failure_in_method('close', ValueError) + + def test_with_other_error_in_acquire_without_raise(self): + logging.raiseExceptions = False + 
self._test_with_failure_in_method('acquire', IndexError) + + def test_with_other_error_in_flush_without_raise(self): + logging.raiseExceptions = False + self._test_with_failure_in_method('flush', IndexError) + + def test_with_other_error_in_close_without_raise(self): + logging.raiseExceptions = False + self._test_with_failure_in_method('close', IndexError) + + def test_with_other_error_in_acquire_with_raise(self): + logging.raiseExceptions = True + self.assertRaises(IndexError, self._test_with_failure_in_method, + 'acquire', IndexError) + + def test_with_other_error_in_flush_with_raise(self): + logging.raiseExceptions = True + self.assertRaises(IndexError, self._test_with_failure_in_method, + 'flush', IndexError) + + def test_with_other_error_in_close_with_raise(self): + logging.raiseExceptions = True + self.assertRaises(IndexError, self._test_with_failure_in_method, + 'close', IndexError) + + +class ModuleLevelMiscTest(BaseTest): + + """Test suite for some module level methods.""" + + def test_disable(self): + old_disable = logging.root.manager.disable + # confirm our assumptions are correct + self.assertEqual(old_disable, 0) + self.addCleanup(logging.disable, old_disable) + + logging.disable(83) + self.assertEqual(logging.root.manager.disable, 83) + + self.assertRaises(ValueError, logging.disable, "doesnotexists") + + class _NotAnIntOrString: + pass + + self.assertRaises(TypeError, logging.disable, _NotAnIntOrString()) + + logging.disable("WARN") + + # test the default value introduced in 3.7 + # (Issue #28524) + logging.disable() + self.assertEqual(logging.root.manager.disable, logging.CRITICAL) + + def _test_log(self, method, level=None): + called = [] + support.patch(self, logging, 'basicConfig', + lambda *a, **kw: called.append((a, kw))) + + recording = RecordingHandler() + logging.root.addHandler(recording) + + log_method = getattr(logging, method) + if level is not None: + log_method(level, "test me: %r", recording) + else: + log_method("test me: %r", 
recording) + + self.assertEqual(len(recording.records), 1) + record = recording.records[0] + self.assertEqual(record.getMessage(), "test me: %r" % recording) + + expected_level = level if level is not None else getattr(logging, method.upper()) + self.assertEqual(record.levelno, expected_level) + + # basicConfig was not called! + self.assertEqual(called, []) + + def test_log(self): + self._test_log('log', logging.ERROR) + + def test_debug(self): + self._test_log('debug') + + def test_info(self): + self._test_log('info') + + def test_warning(self): + self._test_log('warning') + + def test_error(self): + self._test_log('error') + + def test_critical(self): + self._test_log('critical') + + def test_set_logger_class(self): + self.assertRaises(TypeError, logging.setLoggerClass, object) + + class MyLogger(logging.Logger): + pass + + logging.setLoggerClass(MyLogger) + self.assertEqual(logging.getLoggerClass(), MyLogger) + + logging.setLoggerClass(logging.Logger) + self.assertEqual(logging.getLoggerClass(), logging.Logger) + + def test_subclass_logger_cache(self): + # bpo-37258 + message = [] + + class MyLogger(logging.getLoggerClass()): + def __init__(self, name='MyLogger', level=logging.NOTSET): + super().__init__(name, level) + message.append('initialized') + + logging.setLoggerClass(MyLogger) + logger = logging.getLogger('just_some_logger') + self.assertEqual(message, ['initialized']) + stream = io.StringIO() + h = logging.StreamHandler(stream) + logger.addHandler(h) + try: + logger.setLevel(logging.DEBUG) + logger.debug("hello") + self.assertEqual(stream.getvalue().strip(), "hello") + + stream.truncate(0) + stream.seek(0) + + logger.setLevel(logging.INFO) + logger.debug("hello") + self.assertEqual(stream.getvalue(), "") + finally: + logger.removeHandler(h) + h.close() + logging.setLoggerClass(logging.Logger) + + # TODO: RustPython + @unittest.expectedFailure + def test_logging_at_shutdown(self): + # bpo-20037: Doing text I/O late at interpreter shutdown must not crash 
+ code = textwrap.dedent(""" + import logging + + class A: + def __del__(self): + try: + raise ValueError("some error") + except Exception: + logging.exception("exception in __del__") + + a = A() + """) + rc, out, err = assert_python_ok("-c", code) + err = err.decode() + self.assertIn("exception in __del__", err) + self.assertIn("ValueError: some error", err) + + # TODO: RustPython + @unittest.expectedFailure + def test_logging_at_shutdown_open(self): + # bpo-26789: FileHandler keeps a reference to the builtin open() + # function to be able to open or reopen the file during Python + # finalization. + filename = os_helper.TESTFN + self.addCleanup(os_helper.unlink, filename) + + code = textwrap.dedent(f""" + import builtins + import logging + + class A: + def __del__(self): + logging.error("log in __del__") + + # basicConfig() opens the file, but logging.shutdown() closes + # it at Python exit. When A.__del__() is called, + # FileHandler._open() must be called again to re-open the file. + logging.basicConfig(filename={filename!r}, encoding="utf-8") + + a = A() + + # Simulate the Python finalization which removes the builtin + # open() function. 
+            del builtins.open
+        """)
+        assert_python_ok("-c", code)
+
+        with open(filename, encoding="utf-8") as fp:
+            self.assertEqual(fp.read().rstrip(), "ERROR:root:log in __del__")
+
+    def test_recursion_error(self):
+        # Issue 36272
+        code = textwrap.dedent("""
+            import logging
+
+            def rec():
+                logging.error("foo")
+                rec()
+
+            rec()
+            """)
+        rc, out, err = assert_python_failure("-c", code)
+        err = err.decode()
+        self.assertNotIn("Cannot recover from stack overflow.", err)
+        self.assertEqual(rc, 1)
+
+    def test_get_level_names_mapping(self):
+        mapping = logging.getLevelNamesMapping()
+        self.assertEqual(logging._nameToLevel, mapping)  # value is equivalent
+        self.assertIsNot(logging._nameToLevel, mapping)  # but not the internal data
+        new_mapping = logging.getLevelNamesMapping()  # another call -> another copy
+        self.assertIsNot(mapping, new_mapping)  # verify not the same object as before
+        self.assertEqual(mapping, new_mapping)  # but equivalent in value
+
+
+class LogRecordTest(BaseTest):
+    def test_str_rep(self):
+        r = logging.makeLogRecord({})
+        s = str(r)
+        self.assertTrue(s.startswith('<LogRecord: '))
+        self.assertTrue(s.endswith('>'))
+
+    def test_dict_arg(self):
+        h = RecordingHandler()
+        r = logging.getLogger()
+        r.addHandler(h)
+        d = {'less' : 'more' }
+        logging.warning('less is %(less)s', d)
+        self.assertIs(h.records[0].args, d)
+        self.assertEqual(h.records[0].message, 'less is more')
+        r.removeHandler(h)
+        h.close()
+
+    @staticmethod # pickled as target of child process in the following test
+    def _extract_logrecord_process_name(key, logMultiprocessing, conn=None):
+        prev_logMultiprocessing = logging.logMultiprocessing
+        logging.logMultiprocessing = logMultiprocessing
+        try:
+            import multiprocessing as mp
+            name = mp.current_process().name
+
+            r1 = logging.makeLogRecord({'msg': f'msg1_{key}'})
+
+            # https://bugs.python.org/issue45128
+            with support.swap_item(sys.modules, 'multiprocessing', None):
+                r2 = logging.makeLogRecord({'msg': f'msg2_{key}'})
+
+            results = {'processName' : name,
+                       'r1.processName':
r1.processName, + 'r2.processName': r2.processName, + } + finally: + logging.logMultiprocessing = prev_logMultiprocessing + if conn: + conn.send(results) + else: + return results + + def test_multiprocessing(self): + support.skip_if_broken_multiprocessing_synchronize() + multiprocessing_imported = 'multiprocessing' in sys.modules + try: + # logMultiprocessing is True by default + self.assertEqual(logging.logMultiprocessing, True) + + LOG_MULTI_PROCESSING = True + # When logMultiprocessing == True: + # In the main process processName = 'MainProcess' + r = logging.makeLogRecord({}) + self.assertEqual(r.processName, 'MainProcess') + + results = self._extract_logrecord_process_name(1, LOG_MULTI_PROCESSING) + self.assertEqual('MainProcess', results['processName']) + self.assertEqual('MainProcess', results['r1.processName']) + self.assertEqual('MainProcess', results['r2.processName']) + + # In other processes, processName is correct when multiprocessing in imported, + # but it is (incorrectly) defaulted to 'MainProcess' otherwise (bpo-38762). 
+ import multiprocessing + parent_conn, child_conn = multiprocessing.Pipe() + p = multiprocessing.Process( + target=self._extract_logrecord_process_name, + args=(2, LOG_MULTI_PROCESSING, child_conn,) + ) + p.start() + results = parent_conn.recv() + self.assertNotEqual('MainProcess', results['processName']) + self.assertEqual(results['processName'], results['r1.processName']) + self.assertEqual('MainProcess', results['r2.processName']) + p.join() + + finally: + if multiprocessing_imported: + import multiprocessing + + def test_optional(self): + NONE = self.assertIsNone + NOT_NONE = self.assertIsNotNone + + r = logging.makeLogRecord({}) + NOT_NONE(r.thread) + NOT_NONE(r.threadName) + NOT_NONE(r.process) + NOT_NONE(r.processName) + NONE(r.taskName) + log_threads = logging.logThreads + log_processes = logging.logProcesses + log_multiprocessing = logging.logMultiprocessing + log_asyncio_tasks = logging.logAsyncioTasks + try: + logging.logThreads = False + logging.logProcesses = False + logging.logMultiprocessing = False + logging.logAsyncioTasks = False + r = logging.makeLogRecord({}) + + NONE(r.thread) + NONE(r.threadName) + NONE(r.process) + NONE(r.processName) + NONE(r.taskName) + finally: + logging.logThreads = log_threads + logging.logProcesses = log_processes + logging.logMultiprocessing = log_multiprocessing + logging.logAsyncioTasks = log_asyncio_tasks + + async def _make_record_async(self, assertion): + r = logging.makeLogRecord({}) + assertion(r.taskName) + + @support.requires_working_socket() + def test_taskName_with_asyncio_imported(self): + try: + make_record = self._make_record_async + with asyncio.Runner() as runner: + logging.logAsyncioTasks = True + runner.run(make_record(self.assertIsNotNone)) + logging.logAsyncioTasks = False + runner.run(make_record(self.assertIsNone)) + finally: + asyncio.set_event_loop_policy(None) + + @support.requires_working_socket() + def test_taskName_without_asyncio_imported(self): + try: + make_record = 
self._make_record_async + with asyncio.Runner() as runner, support.swap_item(sys.modules, 'asyncio', None): + logging.logAsyncioTasks = True + runner.run(make_record(self.assertIsNone)) + logging.logAsyncioTasks = False + runner.run(make_record(self.assertIsNone)) + finally: + asyncio.set_event_loop_policy(None) + + +class BasicConfigTest(unittest.TestCase): + + """Test suite for logging.basicConfig.""" + + def setUp(self): + super(BasicConfigTest, self).setUp() + self.handlers = logging.root.handlers + self.saved_handlers = logging._handlers.copy() + self.saved_handler_list = logging._handlerList[:] + self.original_logging_level = logging.root.level + self.addCleanup(self.cleanup) + logging.root.handlers = [] + + def tearDown(self): + for h in logging.root.handlers[:]: + logging.root.removeHandler(h) + h.close() + super(BasicConfigTest, self).tearDown() + + def cleanup(self): + setattr(logging.root, 'handlers', self.handlers) + logging._handlers.clear() + logging._handlers.update(self.saved_handlers) + logging._handlerList[:] = self.saved_handler_list + logging.root.setLevel(self.original_logging_level) + + def test_no_kwargs(self): + logging.basicConfig() + + # handler defaults to a StreamHandler to sys.stderr + self.assertEqual(len(logging.root.handlers), 1) + handler = logging.root.handlers[0] + self.assertIsInstance(handler, logging.StreamHandler) + self.assertEqual(handler.stream, sys.stderr) + + formatter = handler.formatter + # format defaults to logging.BASIC_FORMAT + self.assertEqual(formatter._style._fmt, logging.BASIC_FORMAT) + # datefmt defaults to None + self.assertIsNone(formatter.datefmt) + # style defaults to % + self.assertIsInstance(formatter._style, logging.PercentStyle) + + # level is not explicitly set + self.assertEqual(logging.root.level, self.original_logging_level) + + def test_strformatstyle(self): + with support.captured_stdout() as output: + logging.basicConfig(stream=sys.stdout, style="{") + logging.error("Log an error") + 
sys.stdout.seek(0) + self.assertEqual(output.getvalue().strip(), + "ERROR:root:Log an error") + + def test_stringtemplatestyle(self): + with support.captured_stdout() as output: + logging.basicConfig(stream=sys.stdout, style="$") + logging.error("Log an error") + sys.stdout.seek(0) + self.assertEqual(output.getvalue().strip(), + "ERROR:root:Log an error") + + def test_filename(self): + + def cleanup(h1, h2, fn): + h1.close() + h2.close() + os.remove(fn) + + logging.basicConfig(filename='test.log', encoding='utf-8') + + self.assertEqual(len(logging.root.handlers), 1) + handler = logging.root.handlers[0] + self.assertIsInstance(handler, logging.FileHandler) + + expected = logging.FileHandler('test.log', 'a', encoding='utf-8') + self.assertEqual(handler.stream.mode, expected.stream.mode) + self.assertEqual(handler.stream.name, expected.stream.name) + self.addCleanup(cleanup, handler, expected, 'test.log') + + def test_filemode(self): + + def cleanup(h1, h2, fn): + h1.close() + h2.close() + os.remove(fn) + + logging.basicConfig(filename='test.log', filemode='wb') + + handler = logging.root.handlers[0] + expected = logging.FileHandler('test.log', 'wb') + self.assertEqual(handler.stream.mode, expected.stream.mode) + self.addCleanup(cleanup, handler, expected, 'test.log') + + def test_stream(self): + stream = io.StringIO() + self.addCleanup(stream.close) + logging.basicConfig(stream=stream) + + self.assertEqual(len(logging.root.handlers), 1) + handler = logging.root.handlers[0] + self.assertIsInstance(handler, logging.StreamHandler) + self.assertEqual(handler.stream, stream) + + def test_format(self): + logging.basicConfig(format='%(asctime)s - %(message)s') + + formatter = logging.root.handlers[0].formatter + self.assertEqual(formatter._style._fmt, '%(asctime)s - %(message)s') + + def test_datefmt(self): + logging.basicConfig(datefmt='bar') + + formatter = logging.root.handlers[0].formatter + self.assertEqual(formatter.datefmt, 'bar') + + def test_style(self): + 
logging.basicConfig(style='$') + + formatter = logging.root.handlers[0].formatter + self.assertIsInstance(formatter._style, logging.StringTemplateStyle) + + def test_level(self): + old_level = logging.root.level + self.addCleanup(logging.root.setLevel, old_level) + + logging.basicConfig(level=57) + self.assertEqual(logging.root.level, 57) + # Test that second call has no effect + logging.basicConfig(level=58) + self.assertEqual(logging.root.level, 57) + + def test_incompatible(self): + assertRaises = self.assertRaises + handlers = [logging.StreamHandler()] + stream = sys.stderr + assertRaises(ValueError, logging.basicConfig, filename='test.log', + stream=stream) + assertRaises(ValueError, logging.basicConfig, filename='test.log', + handlers=handlers) + assertRaises(ValueError, logging.basicConfig, stream=stream, + handlers=handlers) + # Issue 23207: test for invalid kwargs + assertRaises(ValueError, logging.basicConfig, loglevel=logging.INFO) + # Should pop both filename and filemode even if filename is None + logging.basicConfig(filename=None, filemode='a') + + def test_handlers(self): + handlers = [ + logging.StreamHandler(), + logging.StreamHandler(sys.stdout), + logging.StreamHandler(), + ] + f = logging.Formatter() + handlers[2].setFormatter(f) + logging.basicConfig(handlers=handlers) + self.assertIs(handlers[0], logging.root.handlers[0]) + self.assertIs(handlers[1], logging.root.handlers[1]) + self.assertIs(handlers[2], logging.root.handlers[2]) + self.assertIsNotNone(handlers[0].formatter) + self.assertIsNotNone(handlers[1].formatter) + self.assertIs(handlers[2].formatter, f) + self.assertIs(handlers[0].formatter, handlers[1].formatter) + + def test_force(self): + old_string_io = io.StringIO() + new_string_io = io.StringIO() + old_handlers = [logging.StreamHandler(old_string_io)] + new_handlers = [logging.StreamHandler(new_string_io)] + logging.basicConfig(level=logging.WARNING, handlers=old_handlers) + logging.warning('warn') + logging.info('info') + 
logging.debug('debug') + self.assertEqual(len(logging.root.handlers), 1) + logging.basicConfig(level=logging.INFO, handlers=new_handlers, + force=True) + logging.warning('warn') + logging.info('info') + logging.debug('debug') + self.assertEqual(len(logging.root.handlers), 1) + self.assertEqual(old_string_io.getvalue().strip(), + 'WARNING:root:warn') + self.assertEqual(new_string_io.getvalue().strip(), + 'WARNING:root:warn\nINFO:root:info') + + def test_encoding(self): + try: + encoding = 'utf-8' + logging.basicConfig(filename='test.log', encoding=encoding, + errors='strict', + format='%(message)s', level=logging.DEBUG) + + self.assertEqual(len(logging.root.handlers), 1) + handler = logging.root.handlers[0] + self.assertIsInstance(handler, logging.FileHandler) + self.assertEqual(handler.encoding, encoding) + logging.debug('The Øresund Bridge joins Copenhagen to Malmö') + finally: + handler.close() + with open('test.log', encoding='utf-8') as f: + data = f.read().strip() + os.remove('test.log') + self.assertEqual(data, + 'The Øresund Bridge joins Copenhagen to Malmö') + + def test_encoding_errors(self): + try: + encoding = 'ascii' + logging.basicConfig(filename='test.log', encoding=encoding, + errors='ignore', + format='%(message)s', level=logging.DEBUG) + + self.assertEqual(len(logging.root.handlers), 1) + handler = logging.root.handlers[0] + self.assertIsInstance(handler, logging.FileHandler) + self.assertEqual(handler.encoding, encoding) + logging.debug('The Øresund Bridge joins Copenhagen to Malmö') + finally: + handler.close() + with open('test.log', encoding='utf-8') as f: + data = f.read().strip() + os.remove('test.log') + self.assertEqual(data, 'The resund Bridge joins Copenhagen to Malm') + + def test_encoding_errors_default(self): + try: + encoding = 'ascii' + logging.basicConfig(filename='test.log', encoding=encoding, + format='%(message)s', level=logging.DEBUG) + + self.assertEqual(len(logging.root.handlers), 1) + handler = logging.root.handlers[0] + 
self.assertIsInstance(handler, logging.FileHandler) + self.assertEqual(handler.encoding, encoding) + self.assertEqual(handler.errors, 'backslashreplace') + logging.debug('😂: ☃️: The Øresund Bridge joins Copenhagen to Malmö') + finally: + handler.close() + with open('test.log', encoding='utf-8') as f: + data = f.read().strip() + os.remove('test.log') + self.assertEqual(data, r'\U0001f602: \u2603\ufe0f: The \xd8resund ' + r'Bridge joins Copenhagen to Malm\xf6') + + # TODO: RustPython + @unittest.expectedFailure + def test_encoding_errors_none(self): + # Specifying None should behave as 'strict' + try: + encoding = 'ascii' + logging.basicConfig(filename='test.log', encoding=encoding, + errors=None, + format='%(message)s', level=logging.DEBUG) + + self.assertEqual(len(logging.root.handlers), 1) + handler = logging.root.handlers[0] + self.assertIsInstance(handler, logging.FileHandler) + self.assertEqual(handler.encoding, encoding) + self.assertIsNone(handler.errors) + + message = [] + + def dummy_handle_error(record): + message.append(str(sys.exception())) + + handler.handleError = dummy_handle_error + logging.debug('The Øresund Bridge joins Copenhagen to Malmö') + self.assertTrue(message) + self.assertIn("'ascii' codec can't encode " + "character '\\xd8' in position 4:", message[0]) + finally: + handler.close() + with open('test.log', encoding='utf-8') as f: + data = f.read().strip() + os.remove('test.log') + # didn't write anything due to the encoding error + self.assertEqual(data, r'') + + @support.requires_working_socket() + def test_log_taskName(self): + async def log_record(): + logging.warning('hello world') + + handler = None + log_filename = make_temp_file('.log', 'test-logging-taskname-') + self.addCleanup(os.remove, log_filename) + try: + encoding = 'utf-8' + logging.basicConfig(filename=log_filename, errors='strict', + encoding=encoding, level=logging.WARNING, + format='%(taskName)s - %(message)s') + + self.assertEqual(len(logging.root.handlers), 1) + 
handler = logging.root.handlers[0] + self.assertIsInstance(handler, logging.FileHandler) + + with asyncio.Runner(debug=True) as runner: + logging.logAsyncioTasks = True + runner.run(log_record()) + with open(log_filename, encoding='utf-8') as f: + data = f.read().strip() + self.assertRegex(data, r'Task-\d+ - hello world') + finally: + asyncio.set_event_loop_policy(None) + if handler: + handler.close() + + + def _test_log(self, method, level=None): + # logging.root has no handlers so basicConfig should be called + called = [] + + old_basic_config = logging.basicConfig + def my_basic_config(*a, **kw): + old_basic_config() + old_level = logging.root.level + logging.root.setLevel(100) # avoid having messages in stderr + self.addCleanup(logging.root.setLevel, old_level) + called.append((a, kw)) + + support.patch(self, logging, 'basicConfig', my_basic_config) + + log_method = getattr(logging, method) + if level is not None: + log_method(level, "test me") + else: + log_method("test me") + + # basicConfig was called with no arguments + self.assertEqual(called, [((), {})]) + + def test_log(self): + self._test_log('log', logging.WARNING) + + def test_debug(self): + self._test_log('debug') + + def test_info(self): + self._test_log('info') + + def test_warning(self): + self._test_log('warning') + + def test_error(self): + self._test_log('error') + + def test_critical(self): + self._test_log('critical') + + +class LoggerAdapterTest(unittest.TestCase): + def setUp(self): + super(LoggerAdapterTest, self).setUp() + old_handler_list = logging._handlerList[:] + + self.recording = RecordingHandler() + self.logger = logging.root + self.logger.addHandler(self.recording) + self.addCleanup(self.logger.removeHandler, self.recording) + self.addCleanup(self.recording.close) + + def cleanup(): + logging._handlerList[:] = old_handler_list + + self.addCleanup(cleanup) + self.addCleanup(logging.shutdown) + self.adapter = logging.LoggerAdapter(logger=self.logger, extra=None) + + def 
test_exception(self): + msg = 'testing exception: %r' + exc = None + try: + 1 / 0 + except ZeroDivisionError as e: + exc = e + self.adapter.exception(msg, self.recording) + + self.assertEqual(len(self.recording.records), 1) + record = self.recording.records[0] + self.assertEqual(record.levelno, logging.ERROR) + self.assertEqual(record.msg, msg) + self.assertEqual(record.args, (self.recording,)) + self.assertEqual(record.exc_info, + (exc.__class__, exc, exc.__traceback__)) + + def test_exception_excinfo(self): + try: + 1 / 0 + except ZeroDivisionError as e: + exc = e + + self.adapter.exception('exc_info test', exc_info=exc) + + self.assertEqual(len(self.recording.records), 1) + record = self.recording.records[0] + self.assertEqual(record.exc_info, + (exc.__class__, exc, exc.__traceback__)) + + def test_critical(self): + msg = 'critical test! %r' + self.adapter.critical(msg, self.recording) + + self.assertEqual(len(self.recording.records), 1) + record = self.recording.records[0] + self.assertEqual(record.levelno, logging.CRITICAL) + self.assertEqual(record.msg, msg) + self.assertEqual(record.args, (self.recording,)) + self.assertEqual(record.funcName, 'test_critical') + + def test_is_enabled_for(self): + old_disable = self.adapter.logger.manager.disable + self.adapter.logger.manager.disable = 33 + self.addCleanup(setattr, self.adapter.logger.manager, 'disable', + old_disable) + self.assertFalse(self.adapter.isEnabledFor(32)) + + def test_has_handlers(self): + self.assertTrue(self.adapter.hasHandlers()) + + for handler in self.logger.handlers: + self.logger.removeHandler(handler) + + self.assertFalse(self.logger.hasHandlers()) + self.assertFalse(self.adapter.hasHandlers()) + + def test_nested(self): + msg = 'Adapters can be nested, yo.' 
        # NOTE(review): tail of test_nested — the def is outside this chunk.
        # Stacks one PrefixAdapter on another and checks both prefixes are
        # applied and that `manager` reads/writes pass through the chain.
        adapter = PrefixAdapter(logger=self.logger, extra=None)
        adapter_adapter = PrefixAdapter(logger=adapter, extra=None)
        adapter_adapter.prefix = 'AdapterAdapter'
        self.assertEqual(repr(adapter), repr(adapter_adapter))
        adapter_adapter.log(logging.CRITICAL, msg, self.recording)
        self.assertEqual(len(self.recording.records), 1)
        record = self.recording.records[0]
        self.assertEqual(record.levelno, logging.CRITICAL)
        # Both adapters ran process(), so both prefixes appear in order.
        self.assertEqual(record.msg, f"Adapter AdapterAdapter {msg}")
        self.assertEqual(record.args, (self.recording,))
        self.assertEqual(record.funcName, 'test_nested')
        # `manager` must be delegated through the adapter chain to the
        # underlying logger: a write at any level is visible at every level.
        orig_manager = adapter_adapter.manager
        self.assertIs(adapter.manager, orig_manager)
        self.assertIs(self.logger.manager, orig_manager)
        temp_manager = object()
        try:
            adapter_adapter.manager = temp_manager
            self.assertIs(adapter_adapter.manager, temp_manager)
            self.assertIs(adapter.manager, temp_manager)
            self.assertIs(self.logger.manager, temp_manager)
        finally:
            adapter_adapter.manager = orig_manager
            self.assertIs(adapter_adapter.manager, orig_manager)
            self.assertIs(adapter.manager, orig_manager)
            self.assertIs(self.logger.manager, orig_manager)

    def test_styled_adapter(self):
        # Test an example from the Cookbook: StyleAdapter enables
        # {}-style formatting and must still attribute records to the caller.
        records = self.recording.records
        adapter = StyleAdapter(self.logger)
        adapter.warning('Hello, {}!', 'world')
        self.assertEqual(str(records[-1].msg), 'Hello, world!')
        self.assertEqual(records[-1].funcName, 'test_styled_adapter')
        adapter.log(logging.WARNING, 'Goodbye {}.', 'world')
        self.assertEqual(str(records[-1].msg), 'Goodbye world.')
        self.assertEqual(records[-1].funcName, 'test_styled_adapter')

    def test_nested_styled_adapter(self):
        # StyleAdapter wrapping a PrefixAdapter: the literal '{}' prefix
        # added by PrefixAdapter survives unformatted in the final message.
        records = self.recording.records
        adapter = PrefixAdapter(self.logger)
        adapter.prefix = '{}'
        adapter2 = StyleAdapter(adapter)
        adapter2.warning('Hello, {}!', 'world')
        self.assertEqual(str(records[-1].msg), '{} Hello, world!')
        self.assertEqual(records[-1].funcName, 'test_nested_styled_adapter')
        adapter2.log(logging.WARNING, 'Goodbye {}.', 'world')
        self.assertEqual(str(records[-1].msg), '{} Goodbye world.')
        self.assertEqual(records[-1].funcName, 'test_nested_styled_adapter')

    def test_find_caller_with_stacklevel(self):
        # Each increment of `the_level` should attribute the record to one
        # frame further out in the innermost -> inner -> outer call chain.
        the_level = 1
        trigger = self.adapter.warning

        def innermost():
            trigger('test', stacklevel=the_level)

        def inner():
            innermost()

        def outer():
            inner()

        records = self.recording.records
        outer()
        self.assertEqual(records[-1].funcName, 'innermost')
        lineno = records[-1].lineno
        the_level += 1
        outer()
        self.assertEqual(records[-1].funcName, 'inner')
        self.assertGreater(records[-1].lineno, lineno)
        lineno = records[-1].lineno
        the_level += 1
        outer()
        self.assertEqual(records[-1].funcName, 'outer')
        self.assertGreater(records[-1].lineno, lineno)
        lineno = records[-1].lineno
        the_level += 1
        outer()
        self.assertEqual(records[-1].funcName, 'test_find_caller_with_stacklevel')
        self.assertGreater(records[-1].lineno, lineno)

    def test_extra_in_records(self):
        # `extra` passed at adapter construction must land as attributes
        # on every emitted LogRecord.
        self.adapter = logging.LoggerAdapter(logger=self.logger,
                                             extra={'foo': '1'})

        self.adapter.critical('foo should be here')
        self.assertEqual(len(self.recording.records), 1)
        record = self.recording.records[0]
        self.assertTrue(hasattr(record, 'foo'))
        self.assertEqual(record.foo, '1')

    def test_extra_not_merged_by_default(self):
        # Per-call `extra` is ignored by the default LoggerAdapter.process().
        self.adapter.critical('foo should NOT be here', extra={'foo': 'nope'})
        self.assertEqual(len(self.recording.records), 1)
        record = self.recording.records[0]
        self.assertFalse(hasattr(record, 'foo'))


class PrefixAdapter(logging.LoggerAdapter):
    # Test helper: prepends `prefix` to every message via process().
    prefix = 'Adapter'

    def process(self, msg, kwargs):
        return f"{self.prefix} {msg}", kwargs


class Message:
    # Lazily-formatted {}-style message (the logging Cookbook example);
    # formatting happens only when __str__ is called.
    def __init__(self, fmt, args):
        self.fmt = fmt
        self.args = args

    def __str__(self):
        return self.fmt.format(*self.args)


class StyleAdapter(logging.LoggerAdapter):
    # Adapter supporting {}-style format strings; bumps stacklevel by one so
    # caller attribution skips this wrapper frame.
    def log(self, level, msg, /, *args, stacklevel=1, **kwargs):
        if self.isEnabledFor(level):
            msg, kwargs = self.process(msg, kwargs)
            self.logger.log(level, Message(msg, args), **kwargs,
                            stacklevel=stacklevel+1)


class LoggerTest(BaseTest, AssertErrorMessage):

    def setUp(self):
        super(LoggerTest, self).setUp()
        self.recording = RecordingHandler()
        self.logger = logging.Logger(name='blah')
        self.logger.addHandler(self.recording)
        self.addCleanup(self.logger.removeHandler, self.recording)
        self.addCleanup(self.recording.close)
        self.addCleanup(logging.shutdown)

    def test_set_invalid_level(self):
        self.assert_error_message(
            TypeError, 'Level not an integer or a valid string: None',
            self.logger.setLevel, None)
        self.assert_error_message(
            TypeError, 'Level not an integer or a valid string: (0, 0)',
            self.logger.setLevel, (0, 0))

    def test_exception(self):
        # Logger.exception() must record ERROR level and the active exc_info.
        msg = 'testing exception: %r'
        exc = None
        try:
            1 / 0
        except ZeroDivisionError as e:
            exc = e
            self.logger.exception(msg, self.recording)

        self.assertEqual(len(self.recording.records), 1)
        record = self.recording.records[0]
        self.assertEqual(record.levelno, logging.ERROR)
        self.assertEqual(record.msg, msg)
        self.assertEqual(record.args, (self.recording,))
        self.assertEqual(record.exc_info,
                         (exc.__class__, exc, exc.__traceback__))

    def test_log_invalid_level_with_raise(self):
        with support.swap_attr(logging, 'raiseExceptions', True):
            self.assertRaises(TypeError, self.logger.log, '10', 'test message')

    def test_log_invalid_level_no_raise(self):
        with support.swap_attr(logging, 'raiseExceptions', False):
            self.logger.log('10', 'test message') # no exception happens

    def test_find_caller_with_stack_info(self):
        called = []
        support.patch(self, logging.traceback, 'print_stack',
                      lambda f, file: called.append(file.getvalue()))

        self.logger.findCaller(stack_info=True)

        self.assertEqual(len(called), 1)
        self.assertEqual('Stack (most recent call last):\n', called[0])

    def test_find_caller_with_stacklevel(self):
        # Same ladder as the adapter version, plus a detour through the root
        # logger to check module-level logging.warning() attribution.
        the_level = 1
        trigger = self.logger.warning

        def innermost():
            trigger('test', stacklevel=the_level)

        def inner():
            innermost()

        def outer():
            inner()

        records = self.recording.records
        outer()
        self.assertEqual(records[-1].funcName, 'innermost')
        lineno = records[-1].lineno
        the_level += 1
        outer()
        self.assertEqual(records[-1].funcName, 'inner')
        self.assertGreater(records[-1].lineno, lineno)
        lineno = records[-1].lineno
        the_level += 1
        outer()
        self.assertEqual(records[-1].funcName, 'outer')
        self.assertGreater(records[-1].lineno, lineno)
        lineno = records[-1].lineno
        root_logger = logging.getLogger()
        root_logger.addHandler(self.recording)
        trigger = logging.warning
        outer()
        self.assertEqual(records[-1].funcName, 'outer')
        root_logger.removeHandler(self.recording)
        trigger = self.logger.warning
        the_level += 1
        outer()
        self.assertEqual(records[-1].funcName, 'test_find_caller_with_stacklevel')
        self.assertGreater(records[-1].lineno, lineno)

    def test_make_record_with_extra_overwrite(self):
        # `extra` keys that would clobber built-in record attributes
        # (including 'message'/'asctime') must raise KeyError.
        name = 'my record'
        level = 13
        fn = lno = msg = args = exc_info = func = sinfo = None
        rv = logging._logRecordFactory(name, level, fn, lno, msg,
                                       args,
                                       exc_info, func, sinfo)

        for key in ('message', 'asctime') + tuple(rv.__dict__.keys()):
            extra = {key: 'some value'}
            self.assertRaises(KeyError, self.logger.makeRecord, name, level,
                              fn, lno, msg, args, exc_info,
                              extra=extra, sinfo=sinfo)

    def test_make_record_with_extra_no_overwrite(self):
        name = 'my record'
        level = 13
        fn = lno = msg = args = exc_info = func = sinfo = None
        extra = {'valid_key': 'some value'}
        result = self.logger.makeRecord(name, level, fn, lno, msg, args,
                                        exc_info, extra=extra, sinfo=sinfo)
        self.assertIn('valid_key', result.__dict__)

    def test_has_handlers(self):
        self.assertTrue(self.logger.hasHandlers())

        for handler in self.logger.handlers:
            self.logger.removeHandler(handler)
        self.assertFalse(self.logger.hasHandlers())

    def test_has_handlers_no_propagate(self):
        # propagate=False stops hasHandlers() from seeing ancestor handlers.
        child_logger = logging.getLogger('blah.child')
        child_logger.propagate = False
        self.assertFalse(child_logger.hasHandlers())

    def test_is_enabled_for(self):
        old_disable = self.logger.manager.disable
        self.logger.manager.disable = 23
        self.addCleanup(setattr, self.logger.manager, 'disable', old_disable)
        self.assertFalse(self.logger.isEnabledFor(22))

    def test_is_enabled_for_disabled_logger(self):
        old_disabled = self.logger.disabled
        old_disable = self.logger.manager.disable

        self.logger.disabled = True
        self.logger.manager.disable = 21

        self.addCleanup(setattr, self.logger, 'disabled', old_disabled)
        self.addCleanup(setattr, self.logger.manager, 'disable', old_disable)

        self.assertFalse(self.logger.isEnabledFor(22))

    def test_root_logger_aliases(self):
        root = logging.getLogger()
        self.assertIs(root, logging.root)
        self.assertIs(root, logging.getLogger(None))
        self.assertIs(root, logging.getLogger(''))
        self.assertIs(root, logging.getLogger('root'))
        self.assertIs(root, logging.getLogger('foo').root)
        self.assertIs(root, logging.getLogger('foo.bar').root)
        self.assertIs(root,
                      logging.getLogger('foo').parent)

        self.assertIsNot(root, logging.getLogger('\0'))
        self.assertIsNot(root, logging.getLogger('foo.bar').parent)

    def test_invalid_names(self):
        self.assertRaises(TypeError, logging.getLogger, any)
        self.assertRaises(TypeError, logging.getLogger, b'foo')

    def test_pickling(self):
        # Pickled loggers must round-trip to the identical registry object.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            for name in ('', 'root', 'foo', 'foo.bar', 'baz.bar'):
                logger = logging.getLogger(name)
                s = pickle.dumps(logger, proto)
                unpickled = pickle.loads(s)
                self.assertIs(unpickled, logger)

    def test_caching(self):
        # Exercises the per-logger isEnabledFor() cache (`_cache`) and the
        # invalidations triggered by setLevel()/disable().
        root = self.root_logger
        logger1 = logging.getLogger("abc")
        logger2 = logging.getLogger("abc.def")

        # Set root logger level and ensure cache is empty
        root.setLevel(logging.ERROR)
        self.assertEqual(logger2.getEffectiveLevel(), logging.ERROR)
        self.assertEqual(logger2._cache, {})

        # Ensure cache is populated and calls are consistent
        self.assertTrue(logger2.isEnabledFor(logging.ERROR))
        self.assertFalse(logger2.isEnabledFor(logging.DEBUG))
        self.assertEqual(logger2._cache, {logging.ERROR: True, logging.DEBUG: False})
        self.assertEqual(root._cache, {})
        self.assertTrue(logger2.isEnabledFor(logging.ERROR))

        # Ensure root cache gets populated
        self.assertEqual(root._cache, {})
        self.assertTrue(root.isEnabledFor(logging.ERROR))
        self.assertEqual(root._cache, {logging.ERROR: True})

        # Set parent logger level and ensure caches are emptied
        logger1.setLevel(logging.CRITICAL)
        self.assertEqual(logger2.getEffectiveLevel(), logging.CRITICAL)
        self.assertEqual(logger2._cache, {})

        # Ensure logger2 uses parent logger's effective level
        self.assertFalse(logger2.isEnabledFor(logging.ERROR))

        # Set level to NOTSET and ensure caches are empty
        logger2.setLevel(logging.NOTSET)
        self.assertEqual(logger2.getEffectiveLevel(), logging.CRITICAL)
        self.assertEqual(logger2._cache, {})
        self.assertEqual(logger1._cache, {})
        self.assertEqual(root._cache, {})

        #
        # Verify logger2 follows parent and not root
        self.assertFalse(logger2.isEnabledFor(logging.ERROR))
        self.assertTrue(logger2.isEnabledFor(logging.CRITICAL))
        self.assertFalse(logger1.isEnabledFor(logging.ERROR))
        self.assertTrue(logger1.isEnabledFor(logging.CRITICAL))
        self.assertTrue(root.isEnabledFor(logging.ERROR))

        # Disable logging in manager and ensure caches are clear
        logging.disable()
        self.assertEqual(logger2.getEffectiveLevel(), logging.CRITICAL)
        self.assertEqual(logger2._cache, {})
        self.assertEqual(logger1._cache, {})
        self.assertEqual(root._cache, {})

        # Ensure no loggers are enabled
        self.assertFalse(logger1.isEnabledFor(logging.CRITICAL))
        self.assertFalse(logger2.isEnabledFor(logging.CRITICAL))
        self.assertFalse(root.isEnabledFor(logging.CRITICAL))


class BaseFileTest(BaseTest):
    "Base class for handler tests that write log files"

    def setUp(self):
        BaseTest.setUp(self)
        # self.fn is the primary log file; rotated files get registered in
        # self.rmfiles (via assertLogFile) so tearDown can remove them.
        self.fn = make_temp_file(".log", "test_logging-2-")
        self.rmfiles = []

    def tearDown(self):
        for fn in self.rmfiles:
            os.unlink(fn)
        if os.path.exists(self.fn):
            os.unlink(self.fn)
        BaseTest.tearDown(self)

    def assertLogFile(self, filename):
        "Assert a log file is there and register it for deletion"
        self.assertTrue(os.path.exists(filename),
                        msg="Log file %r does not exist" % filename)
        self.rmfiles.append(filename)

    def next_rec(self):
        # Build a DEBUG LogRecord carrying the next sequential test message.
        return logging.LogRecord('n', logging.DEBUG, 'p', 1,
                                 self.next_message(), None, None, None)

class FileHandlerTest(BaseFileTest):
    def test_delay(self):
        # With delay=True the stream (and file) must not exist until the
        # first record is handled.
        os.unlink(self.fn)
        fh = logging.FileHandler(self.fn, encoding='utf-8', delay=True)
        self.assertIsNone(fh.stream)
        self.assertFalse(os.path.exists(self.fn))
        fh.handle(logging.makeLogRecord({}))
        self.assertIsNotNone(fh.stream)
        self.assertTrue(os.path.exists(self.fn))
        fh.close()

    def test_emit_after_closing_in_write_mode(self):
        # Issue #42378: emitting after close() must not truncate and rewrite
        # the file in 'w' mode; only the first record should remain.
        os.unlink(self.fn)
        fh = logging.FileHandler(self.fn, encoding='utf-8', mode='w')
        fh.setFormatter(logging.Formatter('%(message)s'))
        fh.emit(self.next_rec())    # '1'
        fh.close()
        fh.emit(self.next_rec())    # '2'
        with open(self.fn) as fp:
            self.assertEqual(fp.read().strip(), '1')

class RotatingFileHandlerTest(BaseFileTest):
    def test_should_not_rollover(self):
        # If file is empty rollover never occurs
        rh = logging.handlers.RotatingFileHandler(
            self.fn, encoding="utf-8", maxBytes=1)
        self.assertFalse(rh.shouldRollover(None))
        rh.close()

        # If maxBytes is zero rollover never occurs
        rh = logging.handlers.RotatingFileHandler(
            self.fn, encoding="utf-8", maxBytes=0)
        self.assertFalse(rh.shouldRollover(None))
        rh.close()

        with open(self.fn, 'wb') as f:
            f.write(b'\n')
        rh = logging.handlers.RotatingFileHandler(
            self.fn, encoding="utf-8", maxBytes=0)
        self.assertFalse(rh.shouldRollover(None))
        rh.close()

    @unittest.skipIf(support.is_wasi, "WASI does not have /dev/null.")
    def test_should_not_rollover_non_file(self):
        # bpo-45401 - test with special file
        # We set maxBytes to 1 so that rollover would normally happen, except
        # for the check for regular files
        rh = logging.handlers.RotatingFileHandler(
            os.devnull, encoding="utf-8", maxBytes=1)
        self.assertFalse(rh.shouldRollover(self.next_rec()))
        rh.close()

    def test_should_rollover(self):
        with open(self.fn, 'wb') as f:
            f.write(b'\n')
        rh = logging.handlers.RotatingFileHandler(self.fn, encoding="utf-8", maxBytes=2)
        self.assertTrue(rh.shouldRollover(self.next_rec()))
        rh.close()

    def test_file_created(self):
        # checks that the file is created and assumes it was created
        # by us
        os.unlink(self.fn)
        rh = logging.handlers.RotatingFileHandler(self.fn, encoding="utf-8")
        rh.emit(self.next_rec())
        self.assertLogFile(self.fn)
        rh.close()

    def test_max_bytes(self, delay=False):
        # shouldRollover() must trip only once a record would push the file
        # past maxBytes; backups .1/.2 appear in emission order.
        kwargs = {'delay': delay} if delay else {}
        os.unlink(self.fn)
        rh = logging.handlers.RotatingFileHandler(
            self.fn, encoding="utf-8", backupCount=2, maxBytes=100,
            **kwargs)
        self.assertIs(os.path.exists(self.fn), not delay)
        small = logging.makeLogRecord({'msg': 'a'})
        large = logging.makeLogRecord({'msg': 'b'*100})
        self.assertFalse(rh.shouldRollover(small))
        self.assertFalse(rh.shouldRollover(large))
        rh.emit(small)
        self.assertLogFile(self.fn)
        self.assertFalse(os.path.exists(self.fn + ".1"))
        self.assertFalse(rh.shouldRollover(small))
        self.assertTrue(rh.shouldRollover(large))
        rh.emit(large)
        self.assertTrue(os.path.exists(self.fn))
        self.assertLogFile(self.fn + ".1")
        self.assertFalse(os.path.exists(self.fn + ".2"))
        self.assertTrue(rh.shouldRollover(small))
        self.assertTrue(rh.shouldRollover(large))
        rh.close()

    def test_max_bytes_delay(self):
        self.test_max_bytes(delay=True)

    def test_rollover_filenames(self):
        # A custom `namer` must be applied to every rotated backup name.
        def namer(name):
            return name + ".test"
        rh = logging.handlers.RotatingFileHandler(
            self.fn, encoding="utf-8", backupCount=2, maxBytes=1)
        rh.namer = namer
        rh.emit(self.next_rec())
        self.assertLogFile(self.fn)
        self.assertFalse(os.path.exists(namer(self.fn + ".1")))
        rh.emit(self.next_rec())
        self.assertLogFile(namer(self.fn + ".1"))
        self.assertFalse(os.path.exists(namer(self.fn + ".2")))
        rh.emit(self.next_rec())
        self.assertLogFile(namer(self.fn + ".2"))
        self.assertFalse(os.path.exists(namer(self.fn + ".3")))
        rh.emit(self.next_rec())
        # backupCount=2: a third backup must never be created.
        self.assertFalse(os.path.exists(namer(self.fn + ".3")))
        rh.close()

    def test_namer_rotator_inheritance(self):
        # namer/rotator may also be provided by overriding them in a
        # handler subclass, not just by attribute assignment.
        class HandlerWithNamerAndRotator(logging.handlers.RotatingFileHandler):
            def namer(self, name):
                return name + ".test"

            def rotator(self, source, dest):
                if os.path.exists(source):
                    os.replace(source, dest + ".rotated")

        rh = HandlerWithNamerAndRotator(
            self.fn, encoding="utf-8", backupCount=2, maxBytes=1)
        self.assertEqual(rh.namer(self.fn), self.fn + ".test")
        rh.emit(self.next_rec())
        self.assertLogFile(self.fn)
        rh.emit(self.next_rec())
        self.assertLogFile(rh.namer(self.fn + ".1") + ".rotated")
        self.assertFalse(os.path.exists(rh.namer(self.fn + ".1")))
        rh.close()

    @support.requires_zlib()
    def test_rotator(self):
        # A compressing rotator: each backup is zlib-compressed and the
        # uncompressed source removed; contents are verified by decompression.
        def namer(name):
            return name + ".gz"

        def rotator(source, dest):
            with open(source, "rb") as sf:
                data = sf.read()
                compressed = zlib.compress(data, 9)
                with open(dest, "wb") as df:
                    df.write(compressed)
            os.remove(source)

        rh = logging.handlers.RotatingFileHandler(
            self.fn, encoding="utf-8", backupCount=2, maxBytes=1)
        rh.rotator = rotator
        rh.namer = namer
        m1 = self.next_rec()
        rh.emit(m1)
        self.assertLogFile(self.fn)
        m2 = self.next_rec()
        rh.emit(m2)
        fn = namer(self.fn + ".1")
        self.assertLogFile(fn)
        newline = os.linesep
        with open(fn, "rb") as f:
            compressed = f.read()
            data = zlib.decompress(compressed)
            self.assertEqual(data.decode("ascii"), m1.msg + newline)
        rh.emit(self.next_rec())
        fn = namer(self.fn + ".2")
        self.assertLogFile(fn)
        with open(fn, "rb") as f:
            compressed = f.read()
            data = zlib.decompress(compressed)
            self.assertEqual(data.decode("ascii"), m1.msg + newline)
        rh.emit(self.next_rec())
        # After another rollover, .2 now holds the next-oldest message (m2).
        fn = namer(self.fn + ".2")
        with open(fn, "rb") as f:
            compressed = f.read()
            data = zlib.decompress(compressed)
            self.assertEqual(data.decode("ascii"), m2.msg + newline)
        self.assertFalse(os.path.exists(namer(self.fn + ".3")))
        rh.close()

class TimedRotatingFileHandlerTest(BaseFileTest):
    # TODO: RUSTPYTHON
    @unittest.skip("OS dependent bug")
    @unittest.skipIf(support.is_wasi, "WASI does not have /dev/null.")
    def test_should_not_rollover(self):
        # See bpo-45401. Should only ever rollover regular files
        fh = logging.handlers.TimedRotatingFileHandler(
            os.devnull, 'S', encoding="utf-8", backupCount=1)
        time.sleep(1.1)    # a little over a second ...
        r = logging.makeLogRecord({'msg': 'testing - device file'})
        self.assertFalse(fh.shouldRollover(r))
        fh.close()

    # other test methods added below
    def test_rollover(self):
        fh = logging.handlers.TimedRotatingFileHandler(
            self.fn, 'S', encoding="utf-8", backupCount=1)
        fmt = logging.Formatter('%(asctime)s %(message)s')
        fh.setFormatter(fmt)
        r1 = logging.makeLogRecord({'msg': 'testing - initial'})
        fh.emit(r1)
        self.assertLogFile(self.fn)
        time.sleep(1.1)    # a little over a second ...
        r2 = logging.makeLogRecord({'msg': 'testing - after delay'})
        fh.emit(r2)
        fh.close()
        # At this point, we should have a recent rotated file which we
        # can test for the existence of. However, in practice, on some
        # machines which run really slowly, we don't know how far back
        # in time to go to look for the log file. So, we go back a fair
        # bit, and stop as soon as we see a rotated file. In theory this
        # could of course still fail, but the chances are lower.
        found = False
        now = datetime.datetime.now()
        GO_BACK = 5 * 60 # seconds
        for secs in range(GO_BACK):
            prev = now - datetime.timedelta(seconds=secs)
            fn = self.fn + prev.strftime(".%Y-%m-%d_%H-%M-%S")
            found = os.path.exists(fn)
            if found:
                self.rmfiles.append(fn)
                break
        msg = 'No rotated files found, went back %d seconds' % GO_BACK
        if not found:
            # print additional diagnostics
            dn, fn = os.path.split(self.fn)
            files = [f for f in os.listdir(dn) if f.startswith(fn)]
            print('Test time: %s' % now.strftime("%Y-%m-%d %H-%M-%S"), file=sys.stderr)
            print('The only matching files are: %s' % files, file=sys.stderr)
            for f in files:
                print('Contents of %s:' % f)
                path = os.path.join(dn, f)
                with open(path, 'r') as tf:
                    print(tf.read())
        self.assertTrue(found, msg=msg)

    # TODO: RUSTPYTHON
    @unittest.skip("OS dependent bug")
    def test_rollover_at_midnight(self, weekly=False):
        os_helper.unlink(self.fn)
        now = datetime.datetime.now()
        atTime = now.time()
        if not 0.1 < atTime.microsecond/1e6 < 0.9:
            # The test requires all records to be emitted within
            # the range of the same whole second.
            time.sleep((0.1 - atTime.microsecond/1e6) % 1.0)
            now = datetime.datetime.now()
            atTime = now.time()
        atTime = atTime.replace(microsecond=0)
        fmt = logging.Formatter('%(asctime)s %(message)s')
        when = f'W{now.weekday()}' if weekly else 'MIDNIGHT'
        for i in range(3):
            fh = logging.handlers.TimedRotatingFileHandler(
                self.fn, encoding="utf-8", when=when, atTime=atTime)
            fh.setFormatter(fmt)
            r2 = logging.makeLogRecord({'msg': f'testing1 {i}'})
            fh.emit(r2)
            fh.close()
        self.assertLogFile(self.fn)
        with open(self.fn, encoding="utf-8") as f:
            for i, line in enumerate(f):
                self.assertIn(f'testing1 {i}', line)

        # Backdate the file's mtime so the next handler decides to roll over.
        os.utime(self.fn, (now.timestamp() - 1,)*2)
        for i in range(2):
            fh = logging.handlers.TimedRotatingFileHandler(
                self.fn, encoding="utf-8", when=when, atTime=atTime)
            fh.setFormatter(fmt)
            r2 = logging.makeLogRecord({'msg': f'testing2 {i}'})
            fh.emit(r2)
            fh.close()
        rolloverDate = now - datetime.timedelta(days=7 if weekly else 1)
        otherfn = f'{self.fn}.{rolloverDate:%Y-%m-%d}'
        self.assertLogFile(otherfn)
        with open(self.fn, encoding="utf-8") as f:
            for i, line in enumerate(f):
                self.assertIn(f'testing2 {i}', line)
        with open(otherfn, encoding="utf-8") as f:
            for i, line in enumerate(f):
                self.assertIn(f'testing1 {i}', line)

    # TODO: RUSTPYTHON
    @unittest.skip("OS dependent bug")
    def test_rollover_at_weekday(self):
        self.test_rollover_at_midnight(weekly=True)

    def test_invalid(self):
        # Invalid `when` specifiers must be rejected at construction time.
        assertRaises = self.assertRaises
        assertRaises(ValueError, logging.handlers.TimedRotatingFileHandler,
                     self.fn, 'X', encoding="utf-8", delay=True)
        assertRaises(ValueError, logging.handlers.TimedRotatingFileHandler,
                     self.fn, 'W', encoding="utf-8", delay=True)
        assertRaises(ValueError, logging.handlers.TimedRotatingFileHandler,
                     self.fn, 'W7', encoding="utf-8", delay=True)

    # TODO: Test for utc=False.
+ def test_compute_rollover_daily_attime(self): + currentTime = 0 + rh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='MIDNIGHT', + utc=True, atTime=None) + try: + actual = rh.computeRollover(currentTime) + self.assertEqual(actual, currentTime + 24 * 60 * 60) + + actual = rh.computeRollover(currentTime + 24 * 60 * 60 - 1) + self.assertEqual(actual, currentTime + 24 * 60 * 60) + + actual = rh.computeRollover(currentTime + 24 * 60 * 60) + self.assertEqual(actual, currentTime + 48 * 60 * 60) + + actual = rh.computeRollover(currentTime + 25 * 60 * 60) + self.assertEqual(actual, currentTime + 48 * 60 * 60) + finally: + rh.close() + + atTime = datetime.time(12, 0, 0) + rh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='MIDNIGHT', + utc=True, atTime=atTime) + try: + actual = rh.computeRollover(currentTime) + self.assertEqual(actual, currentTime + 12 * 60 * 60) + + actual = rh.computeRollover(currentTime + 12 * 60 * 60 - 1) + self.assertEqual(actual, currentTime + 12 * 60 * 60) + + actual = rh.computeRollover(currentTime + 12 * 60 * 60) + self.assertEqual(actual, currentTime + 36 * 60 * 60) + + actual = rh.computeRollover(currentTime + 13 * 60 * 60) + self.assertEqual(actual, currentTime + 36 * 60 * 60) + finally: + rh.close() + + # TODO: Test for utc=False. 
    def test_compute_rollover_weekly_attime(self):
        # For when='W<day>' with utc=True and atTime=12:00, the rollover must
        # land at 12:00 on the requested weekday, rolling into next week once
        # that weekday's atTime has passed.
        currentTime = int(time.time())
        today = currentTime - currentTime % 86400

        atTime = datetime.time(12, 0, 0)

        wday = time.gmtime(today).tm_wday
        for day in range(7):
            rh = logging.handlers.TimedRotatingFileHandler(
                self.fn, encoding="utf-8", when='W%d' % day, interval=1, backupCount=0,
                utc=True, atTime=atTime)
            try:
                if wday > day:
                    # The rollover day has already passed this week, so we
                    # go over into next week
                    expected = (7 - wday + day)
                else:
                    expected = (day - wday)
                # At this point expected is in days from now, convert to seconds
                expected *= 24 * 60 * 60
                # Add in the rollover time
                expected += 12 * 60 * 60
                # Add in adjustment for today
                expected += today

                actual = rh.computeRollover(today)
                if actual != expected:
                    # Extra diagnostics on failure: the result depends on the
                    # host timezone, which helps debug CI-only failures.
                    print('failed in timezone: %d' % time.timezone)
                    print('local vars: %s' % locals())
                self.assertEqual(actual, expected)

                actual = rh.computeRollover(today + 12 * 60 * 60 - 1)
                if actual != expected:
                    print('failed in timezone: %d' % time.timezone)
                    print('local vars: %s' % locals())
                self.assertEqual(actual, expected)

                if day == wday:
                    # goes into following week
                    expected += 7 * 24 * 60 * 60
                actual = rh.computeRollover(today + 12 * 60 * 60)
                if actual != expected:
                    print('failed in timezone: %d' % time.timezone)
                    print('local vars: %s' % locals())
                self.assertEqual(actual, expected)

                actual = rh.computeRollover(today + 13 * 60 * 60)
                if actual != expected:
                    print('failed in timezone: %d' % time.timezone)
                    print('local vars: %s' % locals())
                self.assertEqual(actual, expected)
            finally:
                rh.close()

    def test_compute_files_to_delete(self):
        # See bpo-46063 for background
        # Builds several handlers whose log filenames are prefixes of each
        # other (a.b vs a.b.c, custom namers) and checks getFilesToDelete()
        # only offers up that handler's own backups.
        wd = tempfile.mkdtemp(prefix='test_logging_')
        self.addCleanup(shutil.rmtree, wd)
        times = []
        dt = datetime.datetime.now()
        for i in range(10):
            times.append(dt.strftime('%Y-%m-%d_%H-%M-%S'))
            dt += datetime.timedelta(seconds=5)
        prefixes = ('a.b', 'a.b.c',
                    'd.e', 'd.e.f', 'g')
        files = []
        rotators = []
        for prefix in prefixes:
            p = os.path.join(wd, '%s.log' % prefix)
            rotator = logging.handlers.TimedRotatingFileHandler(p, when='s',
                                                                interval=5,
                                                                backupCount=7,
                                                                delay=True)
            rotators.append(rotator)
            if prefix.startswith('a.b'):
                for t in times:
                    files.append('%s.log.%s' % (prefix, t))
            elif prefix.startswith('d.e'):
                # Namer that moves the '.log' suffix to the end of the name.
                def namer(filename):
                    dirname, basename = os.path.split(filename)
                    basename = basename.replace('.log', '') + '.log'
                    return os.path.join(dirname, basename)
                rotator.namer = namer
                for t in times:
                    files.append('%s.%s.log' % (prefix, t))
            elif prefix == 'g':
                # Namer producing an entirely different pattern/extension.
                def namer(filename):
                    dirname, basename = os.path.split(filename)
                    basename = 'g' + basename[6:] + '.oldlog'
                    return os.path.join(dirname, basename)
                rotator.namer = namer
                for t in times:
                    files.append('g%s.oldlog' % t)
        # Create empty files
        for fn in files:
            p = os.path.join(wd, fn)
            with open(p, 'wb') as f:
                pass
        # Now the checks that only the correct files are offered up for deletion
        for i, prefix in enumerate(prefixes):
            rotator = rotators[i]
            candidates = rotator.getFilesToDelete()
            # 10 backups exist, backupCount=7 -> exactly 3 are deletable.
            self.assertEqual(len(candidates), 3, candidates)
            if prefix.startswith('a.b'):
                p = '%s.log.' % prefix
                for c in candidates:
                    d, fn = os.path.split(c)
                    self.assertTrue(fn.startswith(p))
            elif prefix.startswith('d.e'):
                for c in candidates:
                    d, fn = os.path.split(c)
                    self.assertTrue(fn.endswith('.log'), fn)
                    self.assertTrue(fn.startswith(prefix + '.') and
                                    fn[len(prefix) + 2].isdigit())
            elif prefix == 'g':
                for c in candidates:
                    d, fn = os.path.split(c)
                    self.assertTrue(fn.endswith('.oldlog'))
                    self.assertTrue(fn.startswith('g') and fn[1].isdigit())

    def test_compute_files_to_delete_same_filename_different_extensions(self):
        # See GH-93205 for background
        # Backups of 'a.log' and 'a.log.b' overlap textually; deletion
        # candidates must match only the exact timestamp-suffix pattern.
        wd = pathlib.Path(tempfile.mkdtemp(prefix='test_logging_'))
        self.addCleanup(shutil.rmtree, wd)
        times = []
        dt = datetime.datetime.now()
        n_files = 10
        for _ in range(n_files):
            times.append(dt.strftime('%Y-%m-%d_%H-%M-%S'))
            dt += datetime.timedelta(seconds=5)
        prefixes = ('a.log', 'a.log.b')
        files = []
        rotators = []
        for i, prefix in enumerate(prefixes):
            backupCount = i+1
            rotator = logging.handlers.TimedRotatingFileHandler(wd / prefix, when='s',
                                                                interval=5,
                                                                backupCount=backupCount,
                                                                delay=True)
            rotators.append(rotator)
            for t in times:
                files.append('%s.%s' % (prefix, t))
        # Decoys with a trailing '.c' that must never be selected.
        for t in times:
            files.append('a.log.%s.c' % t)
        # Create empty files
        for f in files:
            (wd / f).touch()
        # Now the checks that only the correct files are offered up for deletion
        for i, prefix in enumerate(prefixes):
            backupCount = i+1
            rotator = rotators[i]
            candidates = rotator.getFilesToDelete()
            self.assertEqual(len(candidates), n_files - backupCount, candidates)
            matcher = re.compile(r"^\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\Z")
            for c in candidates:
                d, fn = os.path.split(c)
                self.assertTrue(fn.startswith(prefix+'.'))
                suffix = fn[(len(prefix)+1):]
                self.assertRegex(suffix, matcher)

    # Run with US-style DST rules: DST begins 2 a.m. on second Sunday in
    # March (M3.2.0) and ends 2 a.m. on first Sunday in November (M11.1.0).
    @support.run_with_tz('EST+05EDT,M3.2.0,M11.1.0')
    def test_compute_rollover_MIDNIGHT_local(self):
        # Verifies when='MIDNIGHT' rollover computation in local time across
        # the US DST transitions, including times that do not exist (spring
        # forward) and times that occur twice (fall back, via `fold`).
        # DST begins at 2012-3-11T02:00:00 and ends at 2012-11-4T02:00:00.
        DT = datetime.datetime
        def test(current, expected):
            # Compare as a delta so failures print a human-readable offset.
            actual = fh.computeRollover(current.timestamp())
            diff = actual - expected.timestamp()
            if diff:
                self.assertEqual(diff, 0, datetime.timedelta(seconds=diff))

        fh = logging.handlers.TimedRotatingFileHandler(
            self.fn, encoding="utf-8", when='MIDNIGHT', utc=False)

        test(DT(2012, 3, 10, 23, 59, 59), DT(2012, 3, 11, 0, 0))
        test(DT(2012, 3, 11, 0, 0), DT(2012, 3, 12, 0, 0))
        test(DT(2012, 3, 11, 1, 0), DT(2012, 3, 12, 0, 0))

        test(DT(2012, 11, 3, 23, 59, 59), DT(2012, 11, 4, 0, 0))
        test(DT(2012, 11, 4, 0, 0), DT(2012, 11, 5, 0, 0))
        test(DT(2012, 11, 4, 1, 0), DT(2012, 11, 5, 0, 0))

        fh.close()

        fh = logging.handlers.TimedRotatingFileHandler(
            self.fn, encoding="utf-8", when='MIDNIGHT', utc=False,
            atTime=datetime.time(12, 0, 0))

        test(DT(2012, 3, 10, 11, 59, 59), DT(2012, 3, 10, 12, 0))
        test(DT(2012, 3, 10, 12, 0), DT(2012, 3, 11, 12, 0))
        test(DT(2012, 3, 10, 13, 0), DT(2012, 3, 11, 12, 0))

        test(DT(2012, 11, 3, 11, 59, 59), DT(2012, 11, 3, 12, 0))
        test(DT(2012, 11, 3, 12, 0), DT(2012, 11, 4, 12, 0))
        test(DT(2012, 11, 3, 13, 0), DT(2012, 11, 4, 12, 0))

        fh.close()

        fh = logging.handlers.TimedRotatingFileHandler(
            self.fn, encoding="utf-8", when='MIDNIGHT', utc=False,
            atTime=datetime.time(2, 0, 0))

        test(DT(2012, 3, 10, 1, 59, 59), DT(2012, 3, 10, 2, 0))
        # 2:00:00 is the same as 3:00:00 at 2012-3-11.
        test(DT(2012, 3, 10, 2, 0), DT(2012, 3, 11, 3, 0))
        test(DT(2012, 3, 10, 3, 0), DT(2012, 3, 11, 3, 0))

        test(DT(2012, 3, 11, 1, 59, 59), DT(2012, 3, 11, 3, 0))
        # No time between 2:00:00 and 3:00:00 at 2012-3-11.
        test(DT(2012, 3, 11, 3, 0), DT(2012, 3, 12, 2, 0))
        test(DT(2012, 3, 11, 4, 0), DT(2012, 3, 12, 2, 0))

        test(DT(2012, 11, 3, 1, 59, 59), DT(2012, 11, 3, 2, 0))
        test(DT(2012, 11, 3, 2, 0), DT(2012, 11, 4, 2, 0))
        test(DT(2012, 11, 3, 3, 0), DT(2012, 11, 4, 2, 0))

        # 1:00:00-2:00:00 is repeated twice at 2012-11-4.
        test(DT(2012, 11, 4, 1, 59, 59), DT(2012, 11, 4, 2, 0))
        test(DT(2012, 11, 4, 1, 59, 59, fold=1), DT(2012, 11, 4, 2, 0))
        test(DT(2012, 11, 4, 2, 0), DT(2012, 11, 5, 2, 0))
        test(DT(2012, 11, 4, 3, 0), DT(2012, 11, 5, 2, 0))

        fh.close()

        fh = logging.handlers.TimedRotatingFileHandler(
            self.fn, encoding="utf-8", when='MIDNIGHT', utc=False,
            atTime=datetime.time(2, 30, 0))

        test(DT(2012, 3, 10, 2, 29, 59), DT(2012, 3, 10, 2, 30))
        # No time 2:30:00 at 2012-3-11.
        test(DT(2012, 3, 10, 2, 30), DT(2012, 3, 11, 3, 30))
        test(DT(2012, 3, 10, 3, 0), DT(2012, 3, 11, 3, 30))

        test(DT(2012, 3, 11, 1, 59, 59), DT(2012, 3, 11, 3, 30))
        # No time between 2:00:00 and 3:00:00 at 2012-3-11.
        test(DT(2012, 3, 11, 3, 0), DT(2012, 3, 12, 2, 30))
        test(DT(2012, 3, 11, 3, 30), DT(2012, 3, 12, 2, 30))

        test(DT(2012, 11, 3, 2, 29, 59), DT(2012, 11, 3, 2, 30))
        test(DT(2012, 11, 3, 2, 30), DT(2012, 11, 4, 2, 30))
        test(DT(2012, 11, 3, 3, 0), DT(2012, 11, 4, 2, 30))

        fh.close()

        fh = logging.handlers.TimedRotatingFileHandler(
            self.fn, encoding="utf-8", when='MIDNIGHT', utc=False,
            atTime=datetime.time(1, 30, 0))

        test(DT(2012, 3, 11, 1, 29, 59), DT(2012, 3, 11, 1, 30))
        test(DT(2012, 3, 11, 1, 30), DT(2012, 3, 12, 1, 30))
        test(DT(2012, 3, 11, 1, 59, 59), DT(2012, 3, 12, 1, 30))
        # No time between 2:00:00 and 3:00:00 at 2012-3-11.
        test(DT(2012, 3, 11, 3, 0), DT(2012, 3, 12, 1, 30))
        test(DT(2012, 3, 11, 3, 30), DT(2012, 3, 12, 1, 30))

        # 1:00:00-2:00:00 is repeated twice at 2012-11-4.
        test(DT(2012, 11, 4, 1, 0), DT(2012, 11, 4, 1, 30))
        test(DT(2012, 11, 4, 1, 29, 59), DT(2012, 11, 4, 1, 30))
        test(DT(2012, 11, 4, 1, 30), DT(2012, 11, 5, 1, 30))
        test(DT(2012, 11, 4, 1, 59, 59), DT(2012, 11, 5, 1, 30))
        # It is weird, but the rollover date jumps back from 2012-11-5
        # to 2012-11-4.
        test(DT(2012, 11, 4, 1, 0, fold=1), DT(2012, 11, 4, 1, 30, fold=1))
        test(DT(2012, 11, 4, 1, 29, 59, fold=1), DT(2012, 11, 4, 1, 30, fold=1))
        test(DT(2012, 11, 4, 1, 30, fold=1), DT(2012, 11, 5, 1, 30))
        test(DT(2012, 11, 4, 1, 59, 59, fold=1), DT(2012, 11, 5, 1, 30))
        test(DT(2012, 11, 4, 2, 0), DT(2012, 11, 5, 1, 30))
        test(DT(2012, 11, 4, 2, 30), DT(2012, 11, 5, 1, 30))

        fh.close()

    # Run with US-style DST rules: DST begins 2 a.m. on second Sunday in
    # March (M3.2.0) and ends 2 a.m. on first Sunday in November (M11.1.0).
    @support.run_with_tz('EST+05EDT,M3.2.0,M11.1.0')
    def test_compute_rollover_W6_local(self):
        # Same DST scenarios as the MIDNIGHT variant, but for weekly
        # rollover on Sunday (when='W6') in local time.
        # DST begins at 2012-3-11T02:00:00 and ends at 2012-11-4T02:00:00.
        DT = datetime.datetime
        def test(current, expected):
            actual = fh.computeRollover(current.timestamp())
            diff = actual - expected.timestamp()
            if diff:
                self.assertEqual(diff, 0, datetime.timedelta(seconds=diff))

        fh = logging.handlers.TimedRotatingFileHandler(
            self.fn, encoding="utf-8", when='W6', utc=False)

        test(DT(2012, 3, 4, 23, 59, 59), DT(2012, 3, 5, 0, 0))
        test(DT(2012, 3, 5, 0, 0), DT(2012, 3, 12, 0, 0))
        test(DT(2012, 3, 5, 1, 0), DT(2012, 3, 12, 0, 0))

        test(DT(2012, 10, 28, 23, 59, 59), DT(2012, 10, 29, 0, 0))
        test(DT(2012, 10, 29, 0, 0), DT(2012, 11, 5, 0, 0))
        test(DT(2012, 10, 29, 1, 0), DT(2012, 11, 5, 0, 0))

        fh.close()

        fh = logging.handlers.TimedRotatingFileHandler(
            self.fn, encoding="utf-8", when='W6', utc=False,
            atTime=datetime.time(0, 0, 0))

        test(DT(2012, 3, 10, 23, 59, 59), DT(2012, 3, 11, 0, 0))
        test(DT(2012, 3, 11, 0, 0), DT(2012, 3, 18, 0, 0))
        test(DT(2012, 3, 11, 1, 0), DT(2012, 3, 18, 0, 0))

        test(DT(2012, 11, 3, 23, 59, 59), DT(2012, 11, 4, 0, 0))
        test(DT(2012, 11, 4, 0, 0), DT(2012, 11, 11, 0, 0))
        test(DT(2012, 11, 4, 1, 0), DT(2012, 11, 11, 0, 0))

        fh.close()

        fh = logging.handlers.TimedRotatingFileHandler(
            self.fn, encoding="utf-8", when='W6', utc=False,
            atTime=datetime.time(12, 0, 0))

        test(DT(2012, 3, 4, 11, 59, 59), DT(2012, 3, 4, 12, 0))
        test(DT(2012, 3, 4, 12, 0), DT(2012, 3, 11, 12, 0))
        test(DT(2012, 3, 4, 13, 0), DT(2012, 3, 11, 12, 0))

        test(DT(2012, 10, 28, 11, 59, 59), DT(2012, 10, 28, 12, 0))
        test(DT(2012, 10, 28, 12, 0), DT(2012, 11, 4, 12, 0))
        test(DT(2012, 10, 28, 13, 0), DT(2012, 11, 4, 12, 0))

        fh.close()

        fh = logging.handlers.TimedRotatingFileHandler(
            self.fn, encoding="utf-8", when='W6', utc=False,
            atTime=datetime.time(2, 0, 0))

        test(DT(2012, 3, 4, 1, 59, 59), DT(2012, 3, 4, 2, 0))
        # 2:00:00 is the same as 3:00:00 at 2012-3-11.
        test(DT(2012, 3, 4, 2, 0), DT(2012, 3, 11, 3, 0))
        test(DT(2012, 3, 4, 3, 0), DT(2012, 3, 11, 3, 0))

        test(DT(2012, 3, 11, 1, 59, 59), DT(2012, 3, 11, 3, 0))
        # No time between 2:00:00 and 3:00:00 at 2012-3-11.
        test(DT(2012, 3, 11, 3, 0), DT(2012, 3, 18, 2, 0))
        test(DT(2012, 3, 11, 4, 0), DT(2012, 3, 18, 2, 0))

        test(DT(2012, 10, 28, 1, 59, 59), DT(2012, 10, 28, 2, 0))
        test(DT(2012, 10, 28, 2, 0), DT(2012, 11, 4, 2, 0))
        test(DT(2012, 10, 28, 3, 0), DT(2012, 11, 4, 2, 0))

        # 1:00:00-2:00:00 is repeated twice at 2012-11-4.
        test(DT(2012, 11, 4, 1, 59, 59), DT(2012, 11, 4, 2, 0))
        test(DT(2012, 11, 4, 1, 59, 59, fold=1), DT(2012, 11, 4, 2, 0))
        test(DT(2012, 11, 4, 2, 0), DT(2012, 11, 11, 2, 0))
        test(DT(2012, 11, 4, 3, 0), DT(2012, 11, 11, 2, 0))

        fh.close()

        fh = logging.handlers.TimedRotatingFileHandler(
            self.fn, encoding="utf-8", when='W6', utc=False,
            atTime=datetime.time(2, 30, 0))

        test(DT(2012, 3, 4, 2, 29, 59), DT(2012, 3, 4, 2, 30))
        # No time 2:30:00 at 2012-3-11.
        test(DT(2012, 3, 4, 2, 30), DT(2012, 3, 11, 3, 30))
        test(DT(2012, 3, 4, 3, 0), DT(2012, 3, 11, 3, 30))

        test(DT(2012, 3, 11, 1, 59, 59), DT(2012, 3, 11, 3, 30))
        # No time between 2:00:00 and 3:00:00 at 2012-3-11.
        test(DT(2012, 3, 11, 3, 0), DT(2012, 3, 18, 2, 30))
        test(DT(2012, 3, 11, 3, 30), DT(2012, 3, 18, 2, 30))

        test(DT(2012, 10, 28, 2, 29, 59), DT(2012, 10, 28, 2, 30))
        test(DT(2012, 10, 28, 2, 30), DT(2012, 11, 4, 2, 30))
        test(DT(2012, 10, 28, 3, 0), DT(2012, 11, 4, 2, 30))

        fh.close()

        fh = logging.handlers.TimedRotatingFileHandler(
            self.fn, encoding="utf-8", when='W6', utc=False,
            atTime=datetime.time(1, 30, 0))

        test(DT(2012, 3, 11, 1, 29, 59), DT(2012, 3, 11, 1, 30))
        test(DT(2012, 3, 11, 1, 30), DT(2012, 3, 18, 1, 30))
        test(DT(2012, 3, 11, 1, 59, 59), DT(2012, 3, 18, 1, 30))
        # No time between 2:00:00 and 3:00:00 at 2012-3-11.
        test(DT(2012, 3, 11, 3, 0), DT(2012, 3, 18, 1, 30))
        test(DT(2012, 3, 11, 3, 30), DT(2012, 3, 18, 1, 30))

        # 1:00:00-2:00:00 is repeated twice at 2012-11-4.
        test(DT(2012, 11, 4, 1, 0), DT(2012, 11, 4, 1, 30))
        test(DT(2012, 11, 4, 1, 29, 59), DT(2012, 11, 4, 1, 30))
        test(DT(2012, 11, 4, 1, 30), DT(2012, 11, 11, 1, 30))
        test(DT(2012, 11, 4, 1, 59, 59), DT(2012, 11, 11, 1, 30))
        # It is weird, but the rollover date jumps back from 2012-11-11
        # to 2012-11-4.
        test(DT(2012, 11, 4, 1, 0, fold=1), DT(2012, 11, 4, 1, 30, fold=1))
        test(DT(2012, 11, 4, 1, 29, 59, fold=1), DT(2012, 11, 4, 1, 30, fold=1))
        test(DT(2012, 11, 4, 1, 30, fold=1), DT(2012, 11, 11, 1, 30))
        test(DT(2012, 11, 4, 1, 59, 59, fold=1), DT(2012, 11, 11, 1, 30))
        test(DT(2012, 11, 4, 2, 0), DT(2012, 11, 11, 1, 30))
        test(DT(2012, 11, 4, 2, 30), DT(2012, 11, 11, 1, 30))

        fh.close()

    # Run with US-style DST rules: DST begins 2 a.m. on second Sunday in
    # March (M3.2.0) and ends 2 a.m. on first Sunday in November (M11.1.0).
    @support.run_with_tz('EST+05EDT,M3.2.0,M11.1.0')
    def test_compute_rollover_MIDNIGHT_local_interval(self):
        # DST begins at 2012-3-11T02:00:00 and ends at 2012-11-4T02:00:00.
+ DT = datetime.datetime + def test(current, expected): + actual = fh.computeRollover(current.timestamp()) + diff = actual - expected.timestamp() + if diff: + self.assertEqual(diff, 0, datetime.timedelta(seconds=diff)) + + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='MIDNIGHT', utc=False, interval=3) + + test(DT(2012, 3, 8, 23, 59, 59), DT(2012, 3, 11, 0, 0)) + test(DT(2012, 3, 9, 0, 0), DT(2012, 3, 12, 0, 0)) + test(DT(2012, 3, 9, 1, 0), DT(2012, 3, 12, 0, 0)) + test(DT(2012, 3, 10, 23, 59, 59), DT(2012, 3, 13, 0, 0)) + test(DT(2012, 3, 11, 0, 0), DT(2012, 3, 14, 0, 0)) + test(DT(2012, 3, 11, 1, 0), DT(2012, 3, 14, 0, 0)) + + test(DT(2012, 11, 1, 23, 59, 59), DT(2012, 11, 4, 0, 0)) + test(DT(2012, 11, 2, 0, 0), DT(2012, 11, 5, 0, 0)) + test(DT(2012, 11, 2, 1, 0), DT(2012, 11, 5, 0, 0)) + test(DT(2012, 11, 3, 23, 59, 59), DT(2012, 11, 6, 0, 0)) + test(DT(2012, 11, 4, 0, 0), DT(2012, 11, 7, 0, 0)) + test(DT(2012, 11, 4, 1, 0), DT(2012, 11, 7, 0, 0)) + + fh.close() + + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='MIDNIGHT', utc=False, interval=3, + atTime=datetime.time(12, 0, 0)) + + test(DT(2012, 3, 8, 11, 59, 59), DT(2012, 3, 10, 12, 0)) + test(DT(2012, 3, 8, 12, 0), DT(2012, 3, 11, 12, 0)) + test(DT(2012, 3, 8, 13, 0), DT(2012, 3, 11, 12, 0)) + test(DT(2012, 3, 10, 11, 59, 59), DT(2012, 3, 12, 12, 0)) + test(DT(2012, 3, 10, 12, 0), DT(2012, 3, 13, 12, 0)) + test(DT(2012, 3, 10, 13, 0), DT(2012, 3, 13, 12, 0)) + + test(DT(2012, 11, 1, 11, 59, 59), DT(2012, 11, 3, 12, 0)) + test(DT(2012, 11, 1, 12, 0), DT(2012, 11, 4, 12, 0)) + test(DT(2012, 11, 1, 13, 0), DT(2012, 11, 4, 12, 0)) + test(DT(2012, 11, 3, 11, 59, 59), DT(2012, 11, 5, 12, 0)) + test(DT(2012, 11, 3, 12, 0), DT(2012, 11, 6, 12, 0)) + test(DT(2012, 11, 3, 13, 0), DT(2012, 11, 6, 12, 0)) + + fh.close() + + # Run with US-style DST rules: DST begins 2 a.m. on second Sunday in + # March (M3.2.0) and ends 2 a.m. 
on first Sunday in November (M11.1.0). + @support.run_with_tz('EST+05EDT,M3.2.0,M11.1.0') + def test_compute_rollover_W6_local_interval(self): + # DST begins at 2012-3-11T02:00:00 and ends at 2012-11-4T02:00:00. + DT = datetime.datetime + def test(current, expected): + actual = fh.computeRollover(current.timestamp()) + diff = actual - expected.timestamp() + if diff: + self.assertEqual(diff, 0, datetime.timedelta(seconds=diff)) + + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='W6', utc=False, interval=3) + + test(DT(2012, 2, 19, 23, 59, 59), DT(2012, 3, 5, 0, 0)) + test(DT(2012, 2, 20, 0, 0), DT(2012, 3, 12, 0, 0)) + test(DT(2012, 2, 20, 1, 0), DT(2012, 3, 12, 0, 0)) + test(DT(2012, 3, 4, 23, 59, 59), DT(2012, 3, 19, 0, 0)) + test(DT(2012, 3, 5, 0, 0), DT(2012, 3, 26, 0, 0)) + test(DT(2012, 3, 5, 1, 0), DT(2012, 3, 26, 0, 0)) + + test(DT(2012, 10, 14, 23, 59, 59), DT(2012, 10, 29, 0, 0)) + test(DT(2012, 10, 15, 0, 0), DT(2012, 11, 5, 0, 0)) + test(DT(2012, 10, 15, 1, 0), DT(2012, 11, 5, 0, 0)) + test(DT(2012, 10, 28, 23, 59, 59), DT(2012, 11, 12, 0, 0)) + test(DT(2012, 10, 29, 0, 0), DT(2012, 11, 19, 0, 0)) + test(DT(2012, 10, 29, 1, 0), DT(2012, 11, 19, 0, 0)) + + fh.close() + + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='W6', utc=False, interval=3, + atTime=datetime.time(0, 0, 0)) + + test(DT(2012, 2, 25, 23, 59, 59), DT(2012, 3, 11, 0, 0)) + test(DT(2012, 2, 26, 0, 0), DT(2012, 3, 18, 0, 0)) + test(DT(2012, 2, 26, 1, 0), DT(2012, 3, 18, 0, 0)) + test(DT(2012, 3, 10, 23, 59, 59), DT(2012, 3, 25, 0, 0)) + test(DT(2012, 3, 11, 0, 0), DT(2012, 4, 1, 0, 0)) + test(DT(2012, 3, 11, 1, 0), DT(2012, 4, 1, 0, 0)) + + test(DT(2012, 10, 20, 23, 59, 59), DT(2012, 11, 4, 0, 0)) + test(DT(2012, 10, 21, 0, 0), DT(2012, 11, 11, 0, 0)) + test(DT(2012, 10, 21, 1, 0), DT(2012, 11, 11, 0, 0)) + test(DT(2012, 11, 3, 23, 59, 59), DT(2012, 11, 18, 0, 0)) + test(DT(2012, 11, 4, 0, 0), DT(2012, 11, 25, 0, 0)) + 
test(DT(2012, 11, 4, 1, 0), DT(2012, 11, 25, 0, 0)) + + fh.close() + + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='W6', utc=False, interval=3, + atTime=datetime.time(12, 0, 0)) + + test(DT(2012, 2, 18, 11, 59, 59), DT(2012, 3, 4, 12, 0)) + test(DT(2012, 2, 19, 12, 0), DT(2012, 3, 11, 12, 0)) + test(DT(2012, 2, 19, 13, 0), DT(2012, 3, 11, 12, 0)) + test(DT(2012, 3, 4, 11, 59, 59), DT(2012, 3, 18, 12, 0)) + test(DT(2012, 3, 4, 12, 0), DT(2012, 3, 25, 12, 0)) + test(DT(2012, 3, 4, 13, 0), DT(2012, 3, 25, 12, 0)) + + test(DT(2012, 10, 14, 11, 59, 59), DT(2012, 10, 28, 12, 0)) + test(DT(2012, 10, 14, 12, 0), DT(2012, 11, 4, 12, 0)) + test(DT(2012, 10, 14, 13, 0), DT(2012, 11, 4, 12, 0)) + test(DT(2012, 10, 28, 11, 59, 59), DT(2012, 11, 11, 12, 0)) + test(DT(2012, 10, 28, 12, 0), DT(2012, 11, 18, 12, 0)) + test(DT(2012, 10, 28, 13, 0), DT(2012, 11, 18, 12, 0)) + + fh.close() + + +def secs(**kw): + return datetime.timedelta(**kw) // datetime.timedelta(seconds=1) + +for when, exp in (('S', 1), + ('M', 60), + ('H', 60 * 60), + ('D', 60 * 60 * 24), + ('MIDNIGHT', 60 * 60 * 24), + # current time (epoch start) is a Thursday, W0 means Monday + ('W0', secs(days=4, hours=24)), + ): + for interval in 1, 3: + def test_compute_rollover(self, when=when, interval=interval, exp=exp): + rh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when=when, interval=interval, backupCount=0, utc=True) + currentTime = 0.0 + actual = rh.computeRollover(currentTime) + if when.startswith('W'): + exp += secs(days=7*(interval-1)) + else: + exp *= interval + if exp != actual: + # Failures occur on some systems for MIDNIGHT and W0. 
+ # Print detailed calculation for MIDNIGHT so we can try to see + # what's going on + if when == 'MIDNIGHT': + try: + if rh.utc: + t = time.gmtime(currentTime) + else: + t = time.localtime(currentTime) + currentHour = t[3] + currentMinute = t[4] + currentSecond = t[5] + # r is the number of seconds left between now and midnight + r = logging.handlers._MIDNIGHT - ((currentHour * 60 + + currentMinute) * 60 + + currentSecond) + result = currentTime + r + print('t: %s (%s)' % (t, rh.utc), file=sys.stderr) + print('currentHour: %s' % currentHour, file=sys.stderr) + print('currentMinute: %s' % currentMinute, file=sys.stderr) + print('currentSecond: %s' % currentSecond, file=sys.stderr) + print('r: %s' % r, file=sys.stderr) + print('result: %s' % result, file=sys.stderr) + except Exception as e: + print('exception in diagnostic code: %s' % e, file=sys.stderr) + self.assertEqual(exp, actual) + rh.close() + name = "test_compute_rollover_%s" % when + if interval > 1: + name += "_interval" + test_compute_rollover.__name__ = name + setattr(TimedRotatingFileHandlerTest, name, test_compute_rollover) + + +@unittest.skipUnless(win32evtlog, 'win32evtlog/win32evtlogutil/pywintypes required for this test.') +class NTEventLogHandlerTest(BaseTest): + def test_basic(self): + logtype = 'Application' + elh = win32evtlog.OpenEventLog(None, logtype) + num_recs = win32evtlog.GetNumberOfEventLogRecords(elh) + + try: + h = logging.handlers.NTEventLogHandler('test_logging') + except pywintypes.error as e: + if e.winerror == 5: # access denied + raise unittest.SkipTest('Insufficient privileges to run test') + raise + + r = logging.makeLogRecord({'msg': 'Test Log Message'}) + h.handle(r) + h.close() + # Now see if the event is recorded + self.assertLess(num_recs, win32evtlog.GetNumberOfEventLogRecords(elh)) + flags = win32evtlog.EVENTLOG_BACKWARDS_READ | \ + win32evtlog.EVENTLOG_SEQUENTIAL_READ + found = False + GO_BACK = 100 + events = win32evtlog.ReadEventLog(elh, flags, GO_BACK) + for e in 
events: + if e.SourceName != 'test_logging': + continue + msg = win32evtlogutil.SafeFormatMessage(e, logtype) + if msg != 'Test Log Message\r\n': + continue + found = True + break + msg = 'Record not found in event log, went back %d records' % GO_BACK + self.assertTrue(found, msg=msg) + + +class MiscTestCase(unittest.TestCase): + def test__all__(self): + not_exported = { + 'logThreads', 'logMultiprocessing', 'logProcesses', 'currentframe', + 'PercentStyle', 'StrFormatStyle', 'StringTemplateStyle', + 'Filterer', 'PlaceHolder', 'Manager', 'RootLogger', 'root', + 'threading', 'logAsyncioTasks'} + support.check__all__(self, logging, not_exported=not_exported) + + +# Set the locale to the platform-dependent default. I have no idea +# why the test does this, but in any case we save the current locale +# first and restore it at the end. +def setUpModule(): + unittest.enterModuleContext(support.run_with_locale('LC_ALL', '')) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py index 2f64180652..fa79456ed4 100644 --- a/Lib/test/test_math.py +++ b/Lib/test/test_math.py @@ -4,6 +4,7 @@ from test.support import verbose, requires_IEEE_754 from test import support import unittest +import fractions import itertools import decimal import math @@ -186,6 +187,9 @@ def result_check(expected, got, ulp_tol=5, abs_tol=0.0): # Check exactly equal (applies also to strings representing exceptions) if got == expected: + if not got and not expected: + if math.copysign(1, got) != math.copysign(1, expected): + return f"expected {expected}, got {got} (zero has wrong sign)" return None failure = "not equal" @@ -234,6 +238,10 @@ def __init__(self, value): def __index__(self): return self.value +class BadDescr: + def __get__(self, obj, objtype=None): + raise ValueError + class MathTests(unittest.TestCase): def ftest(self, name, got, expected, ulp_tol=5, abs_tol=0.0): @@ -323,6 +331,7 @@ def testAtan2(self): self.ftest('atan2(0, 1)', 
math.atan2(0, 1), 0) self.ftest('atan2(1, 1)', math.atan2(1, 1), math.pi/4) self.ftest('atan2(1, 0)', math.atan2(1, 0), math.pi/2) + self.ftest('atan2(1, -1)', math.atan2(1, -1), 3*math.pi/4) # math.atan2(0, x) self.ftest('atan2(0., -inf)', math.atan2(0., NINF), math.pi) @@ -416,16 +425,22 @@ def __ceil__(self): return 42 class TestNoCeil: pass + class TestBadCeil: + __ceil__ = BadDescr() self.assertEqual(math.ceil(TestCeil()), 42) self.assertEqual(math.ceil(FloatCeil()), 42) self.assertEqual(math.ceil(FloatLike(42.5)), 43) self.assertRaises(TypeError, math.ceil, TestNoCeil()) + self.assertRaises(ValueError, math.ceil, TestBadCeil()) t = TestNoCeil() t.__ceil__ = lambda *args: args self.assertRaises(TypeError, math.ceil, t) self.assertRaises(TypeError, math.ceil, t, 0) + self.assertEqual(math.ceil(FloatLike(+1.0)), +1.0) + self.assertEqual(math.ceil(FloatLike(-1.0)), -1.0) + @requires_IEEE_754 def testCopysign(self): self.assertEqual(math.copysign(1, 42), 1.0) @@ -566,16 +581,22 @@ def __floor__(self): return 42 class TestNoFloor: pass + class TestBadFloor: + __floor__ = BadDescr() self.assertEqual(math.floor(TestFloor()), 42) self.assertEqual(math.floor(FloatFloor()), 42) self.assertEqual(math.floor(FloatLike(41.9)), 41) self.assertRaises(TypeError, math.floor, TestNoFloor()) + self.assertRaises(ValueError, math.floor, TestBadFloor()) t = TestNoFloor() t.__floor__ = lambda *args: args self.assertRaises(TypeError, math.floor, t) self.assertRaises(TypeError, math.floor, t, 0) + self.assertEqual(math.floor(FloatLike(+1.0)), +1.0) + self.assertEqual(math.floor(FloatLike(-1.0)), -1.0) + def testFmod(self): self.assertRaises(TypeError, math.fmod) self.ftest('fmod(10, 1)', math.fmod(10, 1), 0.0) @@ -597,6 +618,7 @@ def testFmod(self): self.assertEqual(math.fmod(-3.0, NINF), -3.0) self.assertEqual(math.fmod(0.0, 3.0), 0.0) self.assertEqual(math.fmod(0.0, NINF), 0.0) + self.assertRaises(ValueError, math.fmod, INF, INF) def testFrexp(self): self.assertRaises(TypeError, 
math.frexp) @@ -638,7 +660,7 @@ def testFsum(self): def msum(iterable): """Full precision summation. Compute sum(iterable) without any intermediate accumulation of error. Based on the 'lsum' function - at http://code.activestate.com/recipes/393090/ + at https://code.activestate.com/recipes/393090-binary-floating-point-summation-accurate-to-full-p/ """ tmant, texp = 0, 0 @@ -666,6 +688,7 @@ def msum(iterable): ([], 0.0), ([0.0], 0.0), ([1e100, 1.0, -1e100, 1e-100, 1e50, -1.0, -1e50], 1e-100), + ([1e100, 1.0, -1e100, 1e-100, 1e50, -1, -1e50], 1e-100), ([2.0**53, -0.5, -2.0**-54], 2.0**53-1.0), ([2.0**53, 1.0, 2.0**-100], 2.0**53+2.0), ([2.0**53+10.0, 1.0, 2.0**-100], 2.0**53+12.0), @@ -713,6 +736,22 @@ def msum(iterable): s = msum(vals) self.assertEqual(msum(vals), math.fsum(vals)) + self.assertEqual(math.fsum([1.0, math.inf]), math.inf) + self.assertTrue(math.isnan(math.fsum([math.nan, 1.0]))) + self.assertEqual(math.fsum([1e100, FloatLike(1.0), -1e100, 1e-100, + 1e50, FloatLike(-1.0), -1e50]), 1e-100) + self.assertRaises(OverflowError, math.fsum, [1e+308, 1e+308]) + self.assertRaises(ValueError, math.fsum, [math.inf, -math.inf]) + self.assertRaises(TypeError, math.fsum, ['spam']) + self.assertRaises(TypeError, math.fsum, 1) + self.assertRaises(OverflowError, math.fsum, [10**1000]) + + def bad_iter(): + yield 1.0 + raise ZeroDivisionError + + self.assertRaises(ZeroDivisionError, math.fsum, bad_iter()) + def testGcd(self): gcd = math.gcd self.assertEqual(gcd(0, 0), 0) @@ -773,9 +812,13 @@ def testHypot(self): # Test allowable types (those with __float__) self.assertEqual(hypot(12.0, 5.0), 13.0) self.assertEqual(hypot(12, 5), 13) + self.assertEqual(hypot(0.75, -1), 1.25) + self.assertEqual(hypot(-1, 0.75), 1.25) + self.assertEqual(hypot(0.75, FloatLike(-1.)), 1.25) + self.assertEqual(hypot(FloatLike(-1.), 0.75), 1.25) self.assertEqual(hypot(Decimal(12), Decimal(5)), 13) self.assertEqual(hypot(Fraction(12, 32), Fraction(5, 32)), Fraction(13, 32)) - 
self.assertEqual(hypot(bool(1), bool(0), bool(1), bool(1)), math.sqrt(3)) + self.assertEqual(hypot(True, False, True, True, True), 2.0) # Test corner cases self.assertEqual(hypot(0.0, 0.0), 0.0) # Max input is zero @@ -830,6 +873,8 @@ def testHypot(self): scale = FLOAT_MIN / 2.0 ** exp self.assertEqual(math.hypot(4*scale, 3*scale), 5*scale) + self.assertRaises(TypeError, math.hypot, *([1.0]*18), 'spam') + @requires_IEEE_754 @unittest.skipIf(HAVE_DOUBLE_ROUNDING, "hypot() loses accuracy on machines with double rounding") @@ -922,12 +967,16 @@ def testDist(self): # Test allowable types (those with __float__) self.assertEqual(dist((14.0, 1.0), (2.0, -4.0)), 13.0) self.assertEqual(dist((14, 1), (2, -4)), 13) + self.assertEqual(dist((FloatLike(14.), 1), (2, -4)), 13) + self.assertEqual(dist((11, 1), (FloatLike(-1.), -4)), 13) + self.assertEqual(dist((14, FloatLike(-1.)), (2, -6)), 13) + self.assertEqual(dist((14, -1), (2, -6)), 13) self.assertEqual(dist((D(14), D(1)), (D(2), D(-4))), D(13)) self.assertEqual(dist((F(14, 32), F(1, 32)), (F(2, 32), F(-4, 32))), F(13, 32)) - self.assertEqual(dist((True, True, False, True, False), - (True, False, True, True, False)), - sqrt(2.0)) + self.assertEqual(dist((True, True, False, False, True, True), + (True, False, True, False, False, False)), + 2.0) # Test corner cases self.assertEqual(dist((13.25, 12.5, -3.25), @@ -965,6 +1014,8 @@ class T(tuple): dist((1, 2, 3, 4), (5, 6, 7)) with self.assertRaises(ValueError): # Check dimension agree dist((1, 2, 3), (4, 5, 6, 7)) + with self.assertRaises(TypeError): + dist((1,)*17 + ("spam",), (1,)*18) with self.assertRaises(TypeError): # Rejects invalid types dist("abc", "xyz") int_too_big_for_float = 10 ** (sys.float_info.max_10_exp + 5) @@ -972,6 +1023,16 @@ class T(tuple): dist((1, int_too_big_for_float), (2, 3)) with self.assertRaises((ValueError, OverflowError)): dist((2, 3), (1, int_too_big_for_float)) + with self.assertRaises(TypeError): + dist((1,), 2) + with 
self.assertRaises(TypeError): + dist([1], 2) + + class BadFloat: + __float__ = BadDescr() + + with self.assertRaises(ValueError): + dist([1], [BadFloat()]) # Verify that the one dimensional case is equivalent to abs() for i in range(20): @@ -1110,6 +1171,7 @@ def test_lcm(self): def testLdexp(self): self.assertRaises(TypeError, math.ldexp) + self.assertRaises(TypeError, math.ldexp, 2.0, 1.1) self.ftest('ldexp(0,1)', math.ldexp(0,1), 0) self.ftest('ldexp(1,1)', math.ldexp(1,1), 2) self.ftest('ldexp(1,-1)', math.ldexp(1,-1), 0.5) @@ -1142,6 +1204,7 @@ def testLdexp(self): def testLog(self): self.assertRaises(TypeError, math.log) + self.assertRaises(TypeError, math.log, 1, 2, 3) self.ftest('log(1/e)', math.log(1/math.e), -1) self.ftest('log(1)', math.log(1), 0) self.ftest('log(e)', math.log(math.e), 1) @@ -1152,6 +1215,7 @@ def testLog(self): 2302.5850929940457) self.assertRaises(ValueError, math.log, -1.5) self.assertRaises(ValueError, math.log, -10**1000) + self.assertRaises(ValueError, math.log, 10, -10) self.assertRaises(ValueError, math.log, NINF) self.assertEqual(math.log(INF), INF) self.assertTrue(math.isnan(math.log(NAN))) @@ -1202,6 +1266,277 @@ def testLog10(self): self.assertEqual(math.log(INF), INF) self.assertTrue(math.isnan(math.log10(NAN))) + def testSumProd(self): + sumprod = math.sumprod + Decimal = decimal.Decimal + Fraction = fractions.Fraction + + # Core functionality + self.assertEqual(sumprod(iter([10, 20, 30]), (1, 2, 3)), 140) + self.assertEqual(sumprod([1.5, 2.5], [3.5, 4.5]), 16.5) + self.assertEqual(sumprod([], []), 0) + self.assertEqual(sumprod([-1], [1.]), -1) + self.assertEqual(sumprod([1.], [-1]), -1) + + # Type preservation and coercion + for v in [ + (10, 20, 30), + (1.5, -2.5), + (Fraction(3, 5), Fraction(4, 5)), + (Decimal(3.5), Decimal(4.5)), + (2.5, 10), # float/int + (2.5, Fraction(3, 5)), # float/fraction + (25, Fraction(3, 5)), # int/fraction + (25, Decimal(4.5)), # int/decimal + ]: + for p, q in [(v, v), (v, v[::-1])]: + with 
self.subTest(p=p, q=q): + expected = sum(p_i * q_i for p_i, q_i in zip(p, q, strict=True)) + actual = sumprod(p, q) + self.assertEqual(expected, actual) + self.assertEqual(type(expected), type(actual)) + + # Bad arguments + self.assertRaises(TypeError, sumprod) # No args + self.assertRaises(TypeError, sumprod, []) # One arg + self.assertRaises(TypeError, sumprod, [], [], []) # Three args + self.assertRaises(TypeError, sumprod, None, [10]) # Non-iterable + self.assertRaises(TypeError, sumprod, [10], None) # Non-iterable + self.assertRaises(TypeError, sumprod, ['x'], [1.0]) + + # Uneven lengths + self.assertRaises(ValueError, sumprod, [10, 20], [30]) + self.assertRaises(ValueError, sumprod, [10], [20, 30]) + + # Overflows + self.assertEqual(sumprod([10**20], [1]), 10**20) + self.assertEqual(sumprod([1], [10**20]), 10**20) + self.assertEqual(sumprod([10**10], [10**10]), 10**20) + self.assertEqual(sumprod([10**7]*10**5, [10**7]*10**5), 10**19) + self.assertRaises(OverflowError, sumprod, [10**1000], [1.0]) + self.assertRaises(OverflowError, sumprod, [1.0], [10**1000]) + + # Error in iterator + def raise_after(n): + for i in range(n): + yield i + raise RuntimeError + with self.assertRaises(RuntimeError): + sumprod(range(10), raise_after(5)) + with self.assertRaises(RuntimeError): + sumprod(raise_after(5), range(10)) + + from test.test_iter import BasicIterClass + + self.assertEqual(sumprod(BasicIterClass(1), [1]), 0) + self.assertEqual(sumprod([1], BasicIterClass(1)), 0) + + # Error in multiplication + class BadMultiply: + def __mul__(self, other): + raise RuntimeError + def __rmul__(self, other): + raise RuntimeError + with self.assertRaises(RuntimeError): + sumprod([10, BadMultiply(), 30], [1, 2, 3]) + with self.assertRaises(RuntimeError): + sumprod([1, 2, 3], [10, BadMultiply(), 30]) + + # Error in addition + with self.assertRaises(TypeError): + sumprod(['abc', 3], [5, 10]) + with self.assertRaises(TypeError): + sumprod([5, 10], ['abc', 3]) + + # Special values should 
give the same as the pure python recipe + self.assertEqual(sumprod([10.1, math.inf], [20.2, 30.3]), math.inf) + self.assertEqual(sumprod([10.1, math.inf], [math.inf, 30.3]), math.inf) + self.assertEqual(sumprod([10.1, math.inf], [math.inf, math.inf]), math.inf) + self.assertEqual(sumprod([10.1, -math.inf], [20.2, 30.3]), -math.inf) + self.assertTrue(math.isnan(sumprod([10.1, math.inf], [-math.inf, math.inf]))) + self.assertTrue(math.isnan(sumprod([10.1, math.nan], [20.2, 30.3]))) + self.assertTrue(math.isnan(sumprod([10.1, math.inf], [math.nan, 30.3]))) + self.assertTrue(math.isnan(sumprod([10.1, math.inf], [20.3, math.nan]))) + + # Error cases that arose during development + args = ((-5, -5, 10), (1.5, 4611686018427387904, 2305843009213693952)) + self.assertEqual(sumprod(*args), 0.0) + + + @requires_IEEE_754 + @unittest.skipIf(HAVE_DOUBLE_ROUNDING, + "sumprod() accuracy not guaranteed on machines with double rounding") + @support.cpython_only # Other implementations may choose a different algorithm + def test_sumprod_accuracy(self): + sumprod = math.sumprod + self.assertEqual(sumprod([0.1] * 10, [1]*10), 1.0) + self.assertEqual(sumprod([0.1] * 20, [True, False] * 10), 1.0) + self.assertEqual(sumprod([True, False] * 10, [0.1] * 20), 1.0) + self.assertEqual(sumprod([1.0, 10E100, 1.0, -10E100], [1.0]*4), 2.0) + + @support.requires_resource('cpu') + def test_sumprod_stress(self): + sumprod = math.sumprod + product = itertools.product + Decimal = decimal.Decimal + Fraction = fractions.Fraction + + class Int(int): + def __add__(self, other): + return Int(int(self) + int(other)) + def __mul__(self, other): + return Int(int(self) * int(other)) + __radd__ = __add__ + __rmul__ = __mul__ + def __repr__(self): + return f'Int({int(self)})' + + class Flt(float): + def __add__(self, other): + return Int(int(self) + int(other)) + def __mul__(self, other): + return Int(int(self) * int(other)) + __radd__ = __add__ + __rmul__ = __mul__ + def __repr__(self): + return 
f'Flt({int(self)})' + + def baseline_sumprod(p, q): + """This defines the target behavior including exceptions and special values. + However, it is subject to rounding errors, so float inputs should be exactly + representable with only a few bits. + """ + total = 0 + for p_i, q_i in zip(p, q, strict=True): + total += p_i * q_i + return total + + def run(func, *args): + "Make comparing functions easier. Returns error status, type, and result." + try: + result = func(*args) + except (AssertionError, NameError): + raise + except Exception as e: + return type(e), None, 'None' + return None, type(result), repr(result) + + pools = [ + (-5, 10, -2**20, 2**31, 2**40, 2**61, 2**62, 2**80, 1.5, Int(7)), + (5.25, -3.5, 4.75, 11.25, 400.5, 0.046875, 0.25, -1.0, -0.078125), + (-19.0*2**500, 11*2**1000, -3*2**1500, 17*2*333, + 5.25, -3.25, -3.0*2**(-333), 3, 2**513), + (3.75, 2.5, -1.5, float('inf'), -float('inf'), float('NaN'), 14, + 9, 3+4j, Flt(13), 0.0), + (13.25, -4.25, Decimal('10.5'), Decimal('-2.25'), Fraction(13, 8), + Fraction(-11, 16), 4.75 + 0.125j, 97, -41, Int(3)), + (Decimal('6.125'), Decimal('12.375'), Decimal('-2.75'), Decimal(0), + Decimal('Inf'), -Decimal('Inf'), Decimal('NaN'), 12, 13.5), + (-2.0 ** -1000, 11*2**1000, 3, 7, -37*2**32, -2*2**-537, -2*2**-538, + 2*2**-513), + (-7 * 2.0 ** -510, 5 * 2.0 ** -520, 17, -19.0, -6.25), + (11.25, -3.75, -0.625, 23.375, True, False, 7, Int(5)), + ] + + for pool in pools: + for size in range(4): + for args1 in product(pool, repeat=size): + for args2 in product(pool, repeat=size): + args = (args1, args2) + self.assertEqual( + run(baseline_sumprod, *args), + run(sumprod, *args), + args, + ) + + @requires_IEEE_754 + @unittest.skipIf(HAVE_DOUBLE_ROUNDING, + "sumprod() accuracy not guaranteed on machines with double rounding") + @support.cpython_only # Other implementations may choose a different algorithm + @support.requires_resource('cpu') + def test_sumprod_extended_precision_accuracy(self): + import operator + from 
fractions import Fraction + from itertools import starmap + from collections import namedtuple + from math import log2, exp2, fabs + from random import choices, uniform, shuffle + from statistics import median + + DotExample = namedtuple('DotExample', ('x', 'y', 'target_sumprod', 'condition')) + + def DotExact(x, y): + vec1 = map(Fraction, x) + vec2 = map(Fraction, y) + return sum(starmap(operator.mul, zip(vec1, vec2, strict=True))) + + def Condition(x, y): + return 2.0 * DotExact(map(abs, x), map(abs, y)) / abs(DotExact(x, y)) + + def linspace(lo, hi, n): + width = (hi - lo) / (n - 1) + return [lo + width * i for i in range(n)] + + def GenDot(n, c): + """ Algorithm 6.1 (GenDot) works as follows. The condition number (5.7) of + the dot product xT y is proportional to the degree of cancellation. In + order to achieve a prescribed cancellation, we generate the first half of + the vectors x and y randomly within a large exponent range. This range is + chosen according to the anticipated condition number. The second half of x + and y is then constructed choosing xi randomly with decreasing exponent, + and calculating yi such that some cancellation occurs. Finally, we permute + the vectors x, y randomly and calculate the achieved condition number. 
+ """ + + assert n >= 6 + n2 = n // 2 + x = [0.0] * n + y = [0.0] * n + b = log2(c) + + # First half with exponents from 0 to |_b/2_| and random ints in between + e = choices(range(int(b/2)), k=n2) + e[0] = int(b / 2) + 1 + e[-1] = 0.0 + + x[:n2] = [uniform(-1.0, 1.0) * exp2(p) for p in e] + y[:n2] = [uniform(-1.0, 1.0) * exp2(p) for p in e] + + # Second half + e = list(map(round, linspace(b/2, 0.0 , n-n2))) + for i in range(n2, n): + x[i] = uniform(-1.0, 1.0) * exp2(e[i - n2]) + y[i] = (uniform(-1.0, 1.0) * exp2(e[i - n2]) - DotExact(x, y)) / x[i] + + # Shuffle + pairs = list(zip(x, y)) + shuffle(pairs) + x, y = zip(*pairs) + + return DotExample(x, y, DotExact(x, y), Condition(x, y)) + + def RelativeError(res, ex): + x, y, target_sumprod, condition = ex + n = DotExact(list(x) + [-res], list(y) + [1]) + return fabs(n / target_sumprod) + + def Trial(dotfunc, c, n): + ex = GenDot(10, c) + res = dotfunc(ex.x, ex.y) + return RelativeError(res, ex) + + times = 1000 # Number of trials + n = 20 # Length of vectors + c = 1e30 # Target condition number + + # If the following test fails, it means that the C math library + # implementation of fma() is not compliant with the C99 standard + # and is inaccurate. To solve this problem, make a new build + # with the symbol UNRELIABLE_FMA defined. That will enable a + # slower but accurate code path that avoids the fma() call. + relative_err = median(Trial(math.sumprod, c, n) for i in range(times)) + self.assertLess(relative_err, 1e-16) + def testModf(self): self.assertRaises(TypeError, math.modf) @@ -1235,6 +1570,7 @@ def testPow(self): self.assertTrue(math.isnan(math.pow(2, NAN))) self.assertTrue(math.isnan(math.pow(0, NAN))) self.assertEqual(math.pow(1, NAN), 1) + self.assertRaises(OverflowError, math.pow, 1e+100, 1e+100) # pow(0., x) self.assertEqual(math.pow(0., INF), 0.) 
@@ -1550,7 +1886,7 @@ def testTan(self): try: self.assertTrue(math.isnan(math.tan(INF))) self.assertTrue(math.isnan(math.tan(NINF))) - except: + except ValueError: self.assertRaises(ValueError, math.tan, INF) self.assertRaises(ValueError, math.tan, NINF) self.assertTrue(math.isnan(math.tan(NAN))) @@ -1591,6 +1927,8 @@ def __trunc__(self): return 23 class TestNoTrunc: pass + class TestBadTrunc: + __trunc__ = BadDescr() self.assertEqual(math.trunc(TestTrunc()), 23) self.assertEqual(math.trunc(FloatTrunc()), 23) @@ -1599,6 +1937,7 @@ class TestNoTrunc: self.assertRaises(TypeError, math.trunc, 1, 2) self.assertRaises(TypeError, math.trunc, FloatLike(23.5)) self.assertRaises(TypeError, math.trunc, TestNoTrunc()) + self.assertRaises(ValueError, math.trunc, TestBadTrunc()) def testIsfinite(self): self.assertTrue(math.isfinite(0.0)) @@ -1626,11 +1965,11 @@ def testIsinf(self): self.assertFalse(math.isinf(0.)) self.assertFalse(math.isinf(1.)) - @requires_IEEE_754 def test_nan_constant(self): + # `math.nan` must be a quiet NaN with positive sign bit self.assertTrue(math.isnan(math.nan)) + self.assertEqual(math.copysign(1., math.nan), 1.) - @requires_IEEE_754 def test_inf_constant(self): self.assertTrue(math.isinf(math.inf)) self.assertGreater(math.inf, 0.0) @@ -1719,6 +2058,13 @@ def test_testfile(self): except OverflowError: result = 'OverflowError' + # C99+ says for math.h's sqrt: If the argument is +∞ or ±0, it is + # returned, unmodified. On another hand, for csqrt: If z is ±0+0i, + # the result is +0+0i. Lets correct zero sign of er to follow + # first convention. 
+ if id in ['sqrt0002', 'sqrt0003', 'sqrt1001', 'sqrt1023']: + er = math.copysign(er, ar) + # Default tolerances ulp_tol, abs_tol = 5, 0.0 @@ -1802,6 +2148,8 @@ def test_mtestfile(self): '\n '.join(failures)) def test_prod(self): + from fractions import Fraction as F + prod = math.prod self.assertEqual(prod([]), 1) self.assertEqual(prod([], start=5), 5) @@ -1813,6 +2161,14 @@ def test_prod(self): self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0) self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0) self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0) + self.assertEqual(prod([1., F(3, 2)]), 1.5) + + # Error in multiplication + class BadMultiply: + def __rmul__(self, other): + raise RuntimeError + with self.assertRaises(RuntimeError): + prod([10., BadMultiply()]) # Test overflow in fast-path for integers self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32) @@ -2044,11 +2400,20 @@ def test_nextafter(self): float.fromhex('0x1.fffffffffffffp-1')) self.assertEqual(math.nextafter(1.0, INF), float.fromhex('0x1.0000000000001p+0')) + self.assertEqual(math.nextafter(1.0, -INF, steps=1), + float.fromhex('0x1.fffffffffffffp-1')) + self.assertEqual(math.nextafter(1.0, INF, steps=1), + float.fromhex('0x1.0000000000001p+0')) + self.assertEqual(math.nextafter(1.0, -INF, steps=3), + float.fromhex('0x1.ffffffffffffdp-1')) + self.assertEqual(math.nextafter(1.0, INF, steps=3), + float.fromhex('0x1.0000000000003p+0')) # x == y: y is returned - self.assertEqual(math.nextafter(2.0, 2.0), 2.0) - self.assertEqualSign(math.nextafter(-0.0, +0.0), +0.0) - self.assertEqualSign(math.nextafter(+0.0, -0.0), -0.0) + for steps in range(1, 5): + self.assertEqual(math.nextafter(2.0, 2.0, steps=steps), 2.0) + self.assertEqualSign(math.nextafter(-0.0, +0.0, steps=steps), +0.0) + self.assertEqualSign(math.nextafter(+0.0, -0.0, steps=steps), -0.0) # around 0.0 smallest_subnormal = sys.float_info.min * sys.float_info.epsilon @@ -2073,6 +2438,11 @@ def test_nextafter(self): 
self.assertIsNaN(math.nextafter(1.0, NAN)) self.assertIsNaN(math.nextafter(NAN, NAN)) + self.assertEqual(1.0, math.nextafter(1.0, INF, steps=0)) + with self.assertRaises(ValueError): + math.nextafter(1.0, INF, steps=-1) + + @requires_IEEE_754 def test_ulp(self): self.assertEqual(math.ulp(1.0), sys.float_info.epsilon) @@ -2112,6 +2482,14 @@ def __float__(self): # argument to a float. self.assertFalse(getattr(y, "converted", False)) + def test_input_exceptions(self): + self.assertRaises(TypeError, math.exp, "spam") + self.assertRaises(TypeError, math.erf, "spam") + self.assertRaises(TypeError, math.atan2, "spam", 1.0) + self.assertRaises(TypeError, math.atan2, 1.0, "spam") + self.assertRaises(TypeError, math.atan2, 1.0) + self.assertRaises(TypeError, math.atan2, 1.0, 2.0, 3.0) + # Custom assertions. def assertIsNaN(self, value): @@ -2252,7 +2630,7 @@ def test_fractions(self): def load_tests(loader, tests, pattern): from doctest import DocFileSuite - # tests.addTest(DocFileSuite("ieee754.txt")) + tests.addTest(DocFileSuite("ieee754.txt")) return tests if __name__ == '__main__': diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py index 70bfbf09b5..2cceaea244 100644 --- a/Lib/test/test_random.py +++ b/Lib/test/test_random.py @@ -22,8 +22,6 @@ def randomlist(self, n): """Helper function to make a list of random numbers""" return [self.gen.random() for i in range(n)] - # TODO: RUSTPYTHON AttributeError: 'super' object has no attribute 'getstate' - @unittest.expectedFailure def test_autoseed(self): self.gen.seed() state1 = self.gen.getstate() @@ -32,8 +30,6 @@ def test_autoseed(self): state2 = self.gen.getstate() self.assertNotEqual(state1, state2) - # TODO: RUSTPYTHON AttributeError: 'super' object has no attribute 'getstate' - @unittest.expectedFailure def test_saverestore(self): N = 1000 self.gen.seed() @@ -60,7 +56,6 @@ def __hash__(self): self.assertRaises(TypeError, self.gen.seed, 1, 2, 3, 4) self.assertRaises(TypeError, type(self.gen), []) - 
@unittest.skip("TODO: RUSTPYTHON, TypeError: Expected type 'bytes', not 'bytearray'") def test_seed_no_mutate_bug_44018(self): a = bytearray(b'1234') self.gen.seed(a) @@ -386,8 +381,6 @@ def test_getrandbits(self): self.assertRaises(ValueError, self.gen.getrandbits, -1) self.assertRaises(TypeError, self.gen.getrandbits, 10.1) - # TODO: RUSTPYTHON AttributeError: 'super' object has no attribute 'getstate' - @unittest.expectedFailure def test_pickling(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): state = pickle.dumps(self.gen, proto) @@ -396,8 +389,6 @@ def test_pickling(self): restoredseq = [newgen.random() for i in range(10)] self.assertEqual(origseq, restoredseq) - # TODO: RUSTPYTHON AttributeError: 'super' object has no attribute 'getstate' - @unittest.expectedFailure def test_bug_1727780(self): # verify that version-2-pickles can be loaded # fine, whether they are created on 32-bit or 64-bit @@ -600,11 +591,6 @@ def test_bug_42008(self): class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase): gen = random.Random() - # TODO: RUSTPYTHON, TypeError: Expected type 'bytes', not 'bytearray' - @unittest.expectedFailure - def test_seed_no_mutate_bug_44018(self): # TODO: RUSTPYTHON, remove when this passes - super().test_seed_no_mutate_bug_44018() # TODO: RUSTPYTHON, remove when this passes - def test_guaranteed_stable(self): # These sequences are guaranteed to stay the same across versions of python self.gen.seed(3456147, version=1) @@ -675,8 +661,6 @@ def test_bug_31482(self): def test_setstate_first_arg(self): self.assertRaises(ValueError, self.gen.setstate, (1, None, None)) - # TODO: RUSTPYTHON AttributeError: 'super' object has no attribute 'getstate' - @unittest.expectedFailure def test_setstate_middle_arg(self): start_state = self.gen.getstate() # Wrong type, s/b tuple @@ -1012,8 +996,6 @@ def gamma(z, sqrt2pi=(2.0*pi)**0.5): ]) class TestDistributions(unittest.TestCase): - # TODO: RUSTPYTHON ValueError: math domain error - 
@unittest.expectedFailure def test_zeroinputs(self): # Verify that distributions can handle a series of zero inputs' g = random.Random() @@ -1284,15 +1266,6 @@ def test_betavariate_return_zero(self, gammavariate_mock): class TestRandomSubclassing(unittest.TestCase): - # TODO: RUSTPYTHON Unexpected keyword argument newarg - @unittest.expectedFailure - def test_random_subclass_with_kwargs(self): - # SF bug #1486663 -- this used to erroneously raise a TypeError - class Subclass(random.Random): - def __init__(self, newarg=None): - random.Random.__init__(self) - Subclass(newarg=1) - def test_subclasses_overriding_methods(self): # Subclasses with an overridden random, but only the original # getrandbits method should not rely on getrandbits in for randrange, diff --git a/Lib/test/test_smtplib.py b/Lib/test/test_smtplib.py new file mode 100644 index 0000000000..a36d7bbe2a --- /dev/null +++ b/Lib/test/test_smtplib.py @@ -0,0 +1,1568 @@ +import base64 +import email.mime.text +from email.message import EmailMessage +from email.base64mime import body_encode as encode_base64 +import email.utils +import hashlib +import hmac +import socket +import smtplib +import io +import re +import sys +import time +import select +import errno +import textwrap +import threading + +import unittest +from test import support, mock_socket +from test.support import hashlib_helper +from test.support import socket_helper +from test.support import threading_helper +from test.support import asyncore +from test.support import smtpd +from unittest.mock import Mock + + +support.requires_working_socket(module=True) + +HOST = socket_helper.HOST + +if sys.platform == 'darwin': + # select.poll returns a select.POLLHUP at the end of the tests + # on darwin, so just ignore it + def handle_expt(self): + pass + smtpd.SMTPChannel.handle_expt = handle_expt + + +def server(evt, buf, serv): + serv.listen() + evt.set() + try: + conn, addr = serv.accept() + except TimeoutError: + pass + else: + n = 500 + while buf and 
n > 0: + r, w, e = select.select([], [conn], []) + if w: + sent = conn.send(buf) + buf = buf[sent:] + + n -= 1 + + conn.close() + finally: + serv.close() + evt.set() + +class GeneralTests: + + def setUp(self): + smtplib.socket = mock_socket + self.port = 25 + + def tearDown(self): + smtplib.socket = socket + + # This method is no longer used but is retained for backward compatibility, + # so test to make sure it still works. + def testQuoteData(self): + teststr = "abc\n.jkl\rfoo\r\n..blue" + expected = "abc\r\n..jkl\r\nfoo\r\n...blue" + self.assertEqual(expected, smtplib.quotedata(teststr)) + + def testBasic1(self): + mock_socket.reply_with(b"220 Hola mundo") + # connects + client = self.client(HOST, self.port) + client.close() + + def testSourceAddress(self): + mock_socket.reply_with(b"220 Hola mundo") + # connects + client = self.client(HOST, self.port, + source_address=('127.0.0.1',19876)) + self.assertEqual(client.source_address, ('127.0.0.1', 19876)) + client.close() + + def testBasic2(self): + mock_socket.reply_with(b"220 Hola mundo") + # connects, include port in host name + client = self.client("%s:%s" % (HOST, self.port)) + client.close() + + def testLocalHostName(self): + mock_socket.reply_with(b"220 Hola mundo") + # check that supplied local_hostname is used + client = self.client(HOST, self.port, local_hostname="testhost") + self.assertEqual(client.local_hostname, "testhost") + client.close() + + def testTimeoutDefault(self): + mock_socket.reply_with(b"220 Hola mundo") + self.assertIsNone(mock_socket.getdefaulttimeout()) + mock_socket.setdefaulttimeout(30) + self.assertEqual(mock_socket.getdefaulttimeout(), 30) + try: + client = self.client(HOST, self.port) + finally: + mock_socket.setdefaulttimeout(None) + self.assertEqual(client.sock.gettimeout(), 30) + client.close() + + def testTimeoutNone(self): + mock_socket.reply_with(b"220 Hola mundo") + self.assertIsNone(socket.getdefaulttimeout()) + socket.setdefaulttimeout(30) + try: + client = 
self.client(HOST, self.port, timeout=None) + finally: + socket.setdefaulttimeout(None) + self.assertIsNone(client.sock.gettimeout()) + client.close() + + def testTimeoutZero(self): + mock_socket.reply_with(b"220 Hola mundo") + with self.assertRaises(ValueError): + self.client(HOST, self.port, timeout=0) + + def testTimeoutValue(self): + mock_socket.reply_with(b"220 Hola mundo") + client = self.client(HOST, self.port, timeout=30) + self.assertEqual(client.sock.gettimeout(), 30) + client.close() + + def test_debuglevel(self): + mock_socket.reply_with(b"220 Hello world") + client = self.client() + client.set_debuglevel(1) + with support.captured_stderr() as stderr: + client.connect(HOST, self.port) + client.close() + expected = re.compile(r"^connect:", re.MULTILINE) + self.assertRegex(stderr.getvalue(), expected) + + def test_debuglevel_2(self): + mock_socket.reply_with(b"220 Hello world") + client = self.client() + client.set_debuglevel(2) + with support.captured_stderr() as stderr: + client.connect(HOST, self.port) + client.close() + expected = re.compile(r"^\d{2}:\d{2}:\d{2}\.\d{6} connect: ", + re.MULTILINE) + self.assertRegex(stderr.getvalue(), expected) + + +class SMTPGeneralTests(GeneralTests, unittest.TestCase): + + client = smtplib.SMTP + + +class LMTPGeneralTests(GeneralTests, unittest.TestCase): + + client = smtplib.LMTP + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), "test requires Unix domain socket") + def testUnixDomainSocketTimeoutDefault(self): + local_host = '/some/local/lmtp/delivery/program' + mock_socket.reply_with(b"220 Hello world") + try: + client = self.client(local_host, self.port) + finally: + mock_socket.setdefaulttimeout(None) + self.assertIsNone(client.sock.gettimeout()) + client.close() + + def testTimeoutZero(self): + super().testTimeoutZero() + local_host = '/some/local/lmtp/delivery/program' + with self.assertRaises(ValueError): + self.client(local_host, timeout=0) + +# Test server thread using the specified SMTP server class 
+def debugging_server(serv, serv_evt, client_evt): + serv_evt.set() + + try: + if hasattr(select, 'poll'): + poll_fun = asyncore.poll2 + else: + poll_fun = asyncore.poll + + n = 1000 + while asyncore.socket_map and n > 0: + poll_fun(0.01, asyncore.socket_map) + + # when the client conversation is finished, it will + # set client_evt, and it's then ok to kill the server + if client_evt.is_set(): + serv.close() + break + + n -= 1 + + except TimeoutError: + pass + finally: + if not client_evt.is_set(): + # allow some time for the client to read the result + time.sleep(0.5) + serv.close() + asyncore.close_all() + serv_evt.set() + +MSG_BEGIN = '---------- MESSAGE FOLLOWS ----------\n' +MSG_END = '------------ END MESSAGE ------------\n' + +# NOTE: Some SMTP objects in the tests below are created with a non-default +# local_hostname argument to the constructor, since (on some systems) the FQDN +# lookup caused by the default local_hostname sometimes takes so long that the +# test server times out, causing the test to fail. 
+ +# Test behavior of smtpd.DebuggingServer +class DebuggingServerTests(unittest.TestCase): + + maxDiff = None + + def setUp(self): + self.thread_key = threading_helper.threading_setup() + self.real_getfqdn = socket.getfqdn + socket.getfqdn = mock_socket.getfqdn + # temporarily replace sys.stdout to capture DebuggingServer output + self.old_stdout = sys.stdout + self.output = io.StringIO() + sys.stdout = self.output + + self.serv_evt = threading.Event() + self.client_evt = threading.Event() + # Capture SMTPChannel debug output + self.old_DEBUGSTREAM = smtpd.DEBUGSTREAM + smtpd.DEBUGSTREAM = io.StringIO() + # Pick a random unused port by passing 0 for the port number + self.serv = smtpd.DebuggingServer((HOST, 0), ('nowhere', -1), + decode_data=True) + # Keep a note of what server host and port were assigned + self.host, self.port = self.serv.socket.getsockname()[:2] + serv_args = (self.serv, self.serv_evt, self.client_evt) + self.thread = threading.Thread(target=debugging_server, args=serv_args) + self.thread.start() + + # wait until server thread has assigned a port number + self.serv_evt.wait() + self.serv_evt.clear() + + def tearDown(self): + socket.getfqdn = self.real_getfqdn + # indicate that the client is finished + self.client_evt.set() + # wait for the server thread to terminate + self.serv_evt.wait() + threading_helper.join_thread(self.thread) + # restore sys.stdout + sys.stdout = self.old_stdout + # restore DEBUGSTREAM + smtpd.DEBUGSTREAM.close() + smtpd.DEBUGSTREAM = self.old_DEBUGSTREAM + del self.thread + self.doCleanups() + threading_helper.threading_cleanup(*self.thread_key) + + def get_output_without_xpeer(self): + test_output = self.output.getvalue() + return re.sub(r'(.*?)^X-Peer:\s*\S+\n(.*)', r'\1\2', + test_output, flags=re.MULTILINE|re.DOTALL) + + def testBasic(self): + # connect + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + smtp.quit() + + def testSourceAddress(self): + # connect + 
src_port = socket_helper.find_unused_port() + try: + smtp = smtplib.SMTP(self.host, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT, + source_address=(self.host, src_port)) + self.addCleanup(smtp.close) + self.assertEqual(smtp.source_address, (self.host, src_port)) + self.assertEqual(smtp.local_hostname, 'localhost') + smtp.quit() + except OSError as e: + if e.errno == errno.EADDRINUSE: + self.skipTest("couldn't bind to source port %d" % src_port) + raise + + def testNOOP(self): + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + expected = (250, b'OK') + self.assertEqual(smtp.noop(), expected) + smtp.quit() + + def testRSET(self): + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + expected = (250, b'OK') + self.assertEqual(smtp.rset(), expected) + smtp.quit() + + def testELHO(self): + # EHLO isn't implemented in DebuggingServer + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + expected = (250, b'\nSIZE 33554432\nHELP') + self.assertEqual(smtp.ehlo(), expected) + smtp.quit() + + def testEXPNNotImplemented(self): + # EXPN isn't implemented in DebuggingServer + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + expected = (502, b'EXPN not implemented') + smtp.putcmd('EXPN') + self.assertEqual(smtp.getreply(), expected) + smtp.quit() + + def test_issue43124_putcmd_escapes_newline(self): + # see: https://bugs.python.org/issue43124 + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + with self.assertRaises(ValueError) as exc: + smtp.putcmd('helo\nX-INJECTED') + self.assertIn("prohibited newline characters", str(exc.exception)) + 
smtp.quit() + + def testVRFY(self): + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + expected = (252, b'Cannot VRFY user, but will accept message ' + \ + b'and attempt delivery') + self.assertEqual(smtp.vrfy('nobody@nowhere.com'), expected) + self.assertEqual(smtp.verify('nobody@nowhere.com'), expected) + smtp.quit() + + def testSecondHELO(self): + # check that a second HELO returns a message that it's a duplicate + # (this behavior is specific to smtpd.SMTPChannel) + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.helo() + expected = (503, b'Duplicate HELO/EHLO') + self.assertEqual(smtp.helo(), expected) + smtp.quit() + + def testHELP(self): + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + self.assertEqual(smtp.help(), b'Supported commands: EHLO HELO MAIL ' + \ + b'RCPT DATA RSET NOOP QUIT VRFY') + smtp.quit() + + def testSend(self): + # connect and send mail + m = 'A test message' + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.sendmail('John', 'Sally', m) + # XXX(nnorwitz): this test is flaky and dies with a bad file descriptor + # in asyncore. This sleep might help, but should really be fixed + # properly by using an Event variable. 
+ time.sleep(0.01) + smtp.quit() + + self.client_evt.set() + self.serv_evt.wait() + self.output.flush() + mexpect = '%s%s\n%s' % (MSG_BEGIN, m, MSG_END) + self.assertEqual(self.output.getvalue(), mexpect) + + def testSendBinary(self): + m = b'A test message' + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.sendmail('John', 'Sally', m) + # XXX (see comment in testSend) + time.sleep(0.01) + smtp.quit() + + self.client_evt.set() + self.serv_evt.wait() + self.output.flush() + mexpect = '%s%s\n%s' % (MSG_BEGIN, m.decode('ascii'), MSG_END) + self.assertEqual(self.output.getvalue(), mexpect) + + def testSendNeedingDotQuote(self): + # Issue 12283 + m = '.A test\n.mes.sage.' + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.sendmail('John', 'Sally', m) + # XXX (see comment in testSend) + time.sleep(0.01) + smtp.quit() + + self.client_evt.set() + self.serv_evt.wait() + self.output.flush() + mexpect = '%s%s\n%s' % (MSG_BEGIN, m, MSG_END) + self.assertEqual(self.output.getvalue(), mexpect) + + def test_issue43124_escape_localhostname(self): + # see: https://bugs.python.org/issue43124 + # connect and send mail + m = 'wazzuuup\nlinetwo' + smtp = smtplib.SMTP(HOST, self.port, local_hostname='hi\nX-INJECTED', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + with self.assertRaises(ValueError) as exc: + smtp.sendmail("hi@me.com", "you@me.com", m) + self.assertIn( + "prohibited newline characters: ehlo hi\\nX-INJECTED", + str(exc.exception), + ) + # XXX (see comment in testSend) + time.sleep(0.01) + smtp.quit() + + debugout = smtpd.DEBUGSTREAM.getvalue() + self.assertNotIn("X-INJECTED", debugout) + + def test_issue43124_escape_options(self): + # see: https://bugs.python.org/issue43124 + # connect and send mail + m = 'wazzuuup\nlinetwo' + smtp = smtplib.SMTP( + HOST, self.port, 
local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + + self.addCleanup(smtp.close) + smtp.sendmail("hi@me.com", "you@me.com", m) + with self.assertRaises(ValueError) as exc: + smtp.mail("hi@me.com", ["X-OPTION\nX-INJECTED-1", "X-OPTION2\nX-INJECTED-2"]) + msg = str(exc.exception) + self.assertIn("prohibited newline characters", msg) + self.assertIn("X-OPTION\\nX-INJECTED-1 X-OPTION2\\nX-INJECTED-2", msg) + # XXX (see comment in testSend) + time.sleep(0.01) + smtp.quit() + + debugout = smtpd.DEBUGSTREAM.getvalue() + self.assertNotIn("X-OPTION", debugout) + self.assertNotIn("X-OPTION2", debugout) + self.assertNotIn("X-INJECTED-1", debugout) + self.assertNotIn("X-INJECTED-2", debugout) + + def testSendNullSender(self): + m = 'A test message' + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.sendmail('<>', 'Sally', m) + # XXX (see comment in testSend) + time.sleep(0.01) + smtp.quit() + + self.client_evt.set() + self.serv_evt.wait() + self.output.flush() + mexpect = '%s%s\n%s' % (MSG_BEGIN, m, MSG_END) + self.assertEqual(self.output.getvalue(), mexpect) + debugout = smtpd.DEBUGSTREAM.getvalue() + sender = re.compile("^sender: <>$", re.MULTILINE) + self.assertRegex(debugout, sender) + + def testSendMessage(self): + m = email.mime.text.MIMEText('A test message') + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.send_message(m, from_addr='John', to_addrs='Sally') + # XXX (see comment in testSend) + time.sleep(0.01) + smtp.quit() + + self.client_evt.set() + self.serv_evt.wait() + self.output.flush() + # Remove the X-Peer header that DebuggingServer adds as figuring out + # exactly what IP address format is put there is not easy (and + # irrelevant to our test). Typically 127.0.0.1 or ::1, but it is + # not always the same as socket.gethostbyname(HOST). 
:( + test_output = self.get_output_without_xpeer() + del m['X-Peer'] + mexpect = '%s%s\n%s' % (MSG_BEGIN, m.as_string(), MSG_END) + self.assertEqual(test_output, mexpect) + + def testSendMessageWithAddresses(self): + m = email.mime.text.MIMEText('A test message') + m['From'] = 'foo@bar.com' + m['To'] = 'John' + m['CC'] = 'Sally, Fred' + m['Bcc'] = 'John Root , "Dinsdale" ' + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.send_message(m) + # XXX (see comment in testSend) + time.sleep(0.01) + smtp.quit() + # make sure the Bcc header is still in the message. + self.assertEqual(m['Bcc'], 'John Root , "Dinsdale" ' + '') + + self.client_evt.set() + self.serv_evt.wait() + self.output.flush() + # Remove the X-Peer header that DebuggingServer adds. + test_output = self.get_output_without_xpeer() + del m['X-Peer'] + # The Bcc header should not be transmitted. + del m['Bcc'] + mexpect = '%s%s\n%s' % (MSG_BEGIN, m.as_string(), MSG_END) + self.assertEqual(test_output, mexpect) + debugout = smtpd.DEBUGSTREAM.getvalue() + sender = re.compile("^sender: foo@bar.com$", re.MULTILINE) + self.assertRegex(debugout, sender) + for addr in ('John', 'Sally', 'Fred', 'root@localhost', + 'warped@silly.walks.com'): + to_addr = re.compile(r"^recips: .*'{}'.*$".format(addr), + re.MULTILINE) + self.assertRegex(debugout, to_addr) + + def testSendMessageWithSomeAddresses(self): + # Make sure nothing breaks if not all of the three 'to' headers exist + m = email.mime.text.MIMEText('A test message') + m['From'] = 'foo@bar.com' + m['To'] = 'John, Dinsdale' + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.send_message(m) + # XXX (see comment in testSend) + time.sleep(0.01) + smtp.quit() + + self.client_evt.set() + self.serv_evt.wait() + self.output.flush() + # Remove the X-Peer header that DebuggingServer adds. 
+ test_output = self.get_output_without_xpeer() + del m['X-Peer'] + mexpect = '%s%s\n%s' % (MSG_BEGIN, m.as_string(), MSG_END) + self.assertEqual(test_output, mexpect) + debugout = smtpd.DEBUGSTREAM.getvalue() + sender = re.compile("^sender: foo@bar.com$", re.MULTILINE) + self.assertRegex(debugout, sender) + for addr in ('John', 'Dinsdale'): + to_addr = re.compile(r"^recips: .*'{}'.*$".format(addr), + re.MULTILINE) + self.assertRegex(debugout, to_addr) + + def testSendMessageWithSpecifiedAddresses(self): + # Make sure addresses specified in call override those in message. + m = email.mime.text.MIMEText('A test message') + m['From'] = 'foo@bar.com' + m['To'] = 'John, Dinsdale' + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.send_message(m, from_addr='joe@example.com', to_addrs='foo@example.net') + # XXX (see comment in testSend) + time.sleep(0.01) + smtp.quit() + + self.client_evt.set() + self.serv_evt.wait() + self.output.flush() + # Remove the X-Peer header that DebuggingServer adds. 
+ test_output = self.get_output_without_xpeer() + del m['X-Peer'] + mexpect = '%s%s\n%s' % (MSG_BEGIN, m.as_string(), MSG_END) + self.assertEqual(test_output, mexpect) + debugout = smtpd.DEBUGSTREAM.getvalue() + sender = re.compile("^sender: joe@example.com$", re.MULTILINE) + self.assertRegex(debugout, sender) + for addr in ('John', 'Dinsdale'): + to_addr = re.compile(r"^recips: .*'{}'.*$".format(addr), + re.MULTILINE) + self.assertNotRegex(debugout, to_addr) + recip = re.compile(r"^recips: .*'foo@example.net'.*$", re.MULTILINE) + self.assertRegex(debugout, recip) + + def testSendMessageWithMultipleFrom(self): + # Sender overrides To + m = email.mime.text.MIMEText('A test message') + m['From'] = 'Bernard, Bianca' + m['Sender'] = 'the_rescuers@Rescue-Aid-Society.com' + m['To'] = 'John, Dinsdale' + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.send_message(m) + # XXX (see comment in testSend) + time.sleep(0.01) + smtp.quit() + + self.client_evt.set() + self.serv_evt.wait() + self.output.flush() + # Remove the X-Peer header that DebuggingServer adds. 
+ test_output = self.get_output_without_xpeer() + del m['X-Peer'] + mexpect = '%s%s\n%s' % (MSG_BEGIN, m.as_string(), MSG_END) + self.assertEqual(test_output, mexpect) + debugout = smtpd.DEBUGSTREAM.getvalue() + sender = re.compile("^sender: the_rescuers@Rescue-Aid-Society.com$", re.MULTILINE) + self.assertRegex(debugout, sender) + for addr in ('John', 'Dinsdale'): + to_addr = re.compile(r"^recips: .*'{}'.*$".format(addr), + re.MULTILINE) + self.assertRegex(debugout, to_addr) + + def testSendMessageResent(self): + m = email.mime.text.MIMEText('A test message') + m['From'] = 'foo@bar.com' + m['To'] = 'John' + m['CC'] = 'Sally, Fred' + m['Bcc'] = 'John Root , "Dinsdale" ' + m['Resent-Date'] = 'Thu, 1 Jan 1970 17:42:00 +0000' + m['Resent-From'] = 'holy@grail.net' + m['Resent-To'] = 'Martha , Jeff' + m['Resent-Bcc'] = 'doe@losthope.net' + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.send_message(m) + # XXX (see comment in testSend) + time.sleep(0.01) + smtp.quit() + + self.client_evt.set() + self.serv_evt.wait() + self.output.flush() + # The Resent-Bcc headers are deleted before serialization. + del m['Bcc'] + del m['Resent-Bcc'] + # Remove the X-Peer header that DebuggingServer adds. 
+ test_output = self.get_output_without_xpeer() + del m['X-Peer'] + mexpect = '%s%s\n%s' % (MSG_BEGIN, m.as_string(), MSG_END) + self.assertEqual(test_output, mexpect) + debugout = smtpd.DEBUGSTREAM.getvalue() + sender = re.compile("^sender: holy@grail.net$", re.MULTILINE) + self.assertRegex(debugout, sender) + for addr in ('my_mom@great.cooker.com', 'Jeff', 'doe@losthope.net'): + to_addr = re.compile(r"^recips: .*'{}'.*$".format(addr), + re.MULTILINE) + self.assertRegex(debugout, to_addr) + + def testSendMessageMultipleResentRaises(self): + m = email.mime.text.MIMEText('A test message') + m['From'] = 'foo@bar.com' + m['To'] = 'John' + m['CC'] = 'Sally, Fred' + m['Bcc'] = 'John Root , "Dinsdale" ' + m['Resent-Date'] = 'Thu, 1 Jan 1970 17:42:00 +0000' + m['Resent-From'] = 'holy@grail.net' + m['Resent-To'] = 'Martha , Jeff' + m['Resent-Bcc'] = 'doe@losthope.net' + m['Resent-Date'] = 'Thu, 2 Jan 1970 17:42:00 +0000' + m['Resent-To'] = 'holy@grail.net' + m['Resent-From'] = 'Martha , Jeff' + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + with self.assertRaises(ValueError): + smtp.send_message(m) + smtp.close() + +class NonConnectingTests(unittest.TestCase): + + def testNotConnected(self): + # Test various operations on an unconnected SMTP object that + # should raise exceptions (at present the attempt in SMTP.send + # to reference the nonexistent 'sock' attribute of the SMTP object + # causes an AttributeError) + smtp = smtplib.SMTP() + self.assertRaises(smtplib.SMTPServerDisconnected, smtp.ehlo) + self.assertRaises(smtplib.SMTPServerDisconnected, + smtp.send, 'test msg') + + def testNonnumericPort(self): + # check that non-numeric port raises OSError + self.assertRaises(OSError, smtplib.SMTP, + "localhost", "bogus") + self.assertRaises(OSError, smtplib.SMTP, + "localhost:bogus") + + def testSockAttributeExists(self): + # check that sock attribute is present outside of a connect() call 
+ # (regression test, the previous behavior raised an + # AttributeError: 'SMTP' object has no attribute 'sock') + with smtplib.SMTP() as smtp: + self.assertIsNone(smtp.sock) + + +class DefaultArgumentsTests(unittest.TestCase): + + def setUp(self): + self.msg = EmailMessage() + self.msg['From'] = 'Páolo ' + self.smtp = smtplib.SMTP() + self.smtp.ehlo = Mock(return_value=(200, 'OK')) + self.smtp.has_extn, self.smtp.sendmail = Mock(), Mock() + + def testSendMessage(self): + expected_mail_options = ('SMTPUTF8', 'BODY=8BITMIME') + self.smtp.send_message(self.msg) + self.smtp.send_message(self.msg) + self.assertEqual(self.smtp.sendmail.call_args_list[0][0][3], + expected_mail_options) + self.assertEqual(self.smtp.sendmail.call_args_list[1][0][3], + expected_mail_options) + + def testSendMessageWithMailOptions(self): + mail_options = ['STARTTLS'] + expected_mail_options = ('STARTTLS', 'SMTPUTF8', 'BODY=8BITMIME') + self.smtp.send_message(self.msg, None, None, mail_options) + self.assertEqual(mail_options, ['STARTTLS']) + self.assertEqual(self.smtp.sendmail.call_args_list[0][0][3], + expected_mail_options) + + +# test response of client to a non-successful HELO message +class BadHELOServerTests(unittest.TestCase): + + def setUp(self): + smtplib.socket = mock_socket + mock_socket.reply_with(b"199 no hello for you!") + self.old_stdout = sys.stdout + self.output = io.StringIO() + sys.stdout = self.output + self.port = 25 + + def tearDown(self): + smtplib.socket = socket + sys.stdout = self.old_stdout + + def testFailingHELO(self): + self.assertRaises(smtplib.SMTPConnectError, smtplib.SMTP, + HOST, self.port, 'localhost', 3) + + +class TooLongLineTests(unittest.TestCase): + respdata = b'250 OK' + (b'.' 
* smtplib._MAXLINE * 2) + b'\n' + + def setUp(self): + self.thread_key = threading_helper.threading_setup() + self.old_stdout = sys.stdout + self.output = io.StringIO() + sys.stdout = self.output + + self.evt = threading.Event() + self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock.settimeout(15) + self.port = socket_helper.bind_port(self.sock) + servargs = (self.evt, self.respdata, self.sock) + self.thread = threading.Thread(target=server, args=servargs) + self.thread.start() + self.evt.wait() + self.evt.clear() + + def tearDown(self): + self.evt.wait() + sys.stdout = self.old_stdout + threading_helper.join_thread(self.thread) + del self.thread + self.doCleanups() + threading_helper.threading_cleanup(*self.thread_key) + + def testLineTooLong(self): + self.assertRaises(smtplib.SMTPResponseException, smtplib.SMTP, + HOST, self.port, 'localhost', 3) + + +sim_users = {'Mr.A@somewhere.com':'John A', + 'Ms.B@xn--fo-fka.com':'Sally B', + 'Mrs.C@somewhereesle.com':'Ruth C', + } + +sim_auth = ('Mr.A@somewhere.com', 'somepassword') +sim_cram_md5_challenge = ('PENCeUxFREJoU0NnbmhNWitOMjNGNn' + 'dAZWx3b29kLmlubm9zb2Z0LmNvbT4=') +sim_lists = {'list-1':['Mr.A@somewhere.com','Mrs.C@somewhereesle.com'], + 'list-2':['Ms.B@xn--fo-fka.com',], + } + +# Simulated SMTP channel & server +class ResponseException(Exception): pass +class SimSMTPChannel(smtpd.SMTPChannel): + + quit_response = None + mail_response = None + rcpt_response = None + data_response = None + rcpt_count = 0 + rset_count = 0 + disconnect = 0 + AUTH = 99 # Add protocol state to enable auth testing. + authenticated_user = None + + def __init__(self, extra_features, *args, **kw): + self._extrafeatures = ''.join( + [ "250-{0}\r\n".format(x) for x in extra_features ]) + super(SimSMTPChannel, self).__init__(*args, **kw) + + # AUTH related stuff. It would be nice if support for this were in smtpd. 
+ def found_terminator(self): + if self.smtp_state == self.AUTH: + line = self._emptystring.join(self.received_lines) + print('Data:', repr(line), file=smtpd.DEBUGSTREAM) + self.received_lines = [] + try: + self.auth_object(line) + except ResponseException as e: + self.smtp_state = self.COMMAND + self.push('%s %s' % (e.smtp_code, e.smtp_error)) + return + super().found_terminator() + + + def smtp_AUTH(self, arg): + if not self.seen_greeting: + self.push('503 Error: send EHLO first') + return + if not self.extended_smtp or 'AUTH' not in self._extrafeatures: + self.push('500 Error: command "AUTH" not recognized') + return + if self.authenticated_user is not None: + self.push( + '503 Bad sequence of commands: already authenticated') + return + args = arg.split() + if len(args) not in [1, 2]: + self.push('501 Syntax: AUTH [initial-response]') + return + auth_object_name = '_auth_%s' % args[0].lower().replace('-', '_') + try: + self.auth_object = getattr(self, auth_object_name) + except AttributeError: + self.push('504 Command parameter not implemented: unsupported ' + ' authentication mechanism {!r}'.format(auth_object_name)) + return + self.smtp_state = self.AUTH + self.auth_object(args[1] if len(args) == 2 else None) + + def _authenticated(self, user, valid): + if valid: + self.authenticated_user = user + self.push('235 Authentication Succeeded') + else: + self.push('535 Authentication credentials invalid') + self.smtp_state = self.COMMAND + + def _decode_base64(self, string): + return base64.decodebytes(string.encode('ascii')).decode('utf-8') + + def _auth_plain(self, arg=None): + if arg is None: + self.push('334 ') + else: + logpass = self._decode_base64(arg) + try: + *_, user, password = logpass.split('\0') + except ValueError as e: + self.push('535 Splitting response {!r} into user and password' + ' failed: {}'.format(logpass, e)) + return + self._authenticated(user, password == sim_auth[1]) + + def _auth_login(self, arg=None): + if arg is None: + # base64 
encoded 'Username:' + self.push('334 VXNlcm5hbWU6') + elif not hasattr(self, '_auth_login_user'): + self._auth_login_user = self._decode_base64(arg) + # base64 encoded 'Password:' + self.push('334 UGFzc3dvcmQ6') + else: + password = self._decode_base64(arg) + self._authenticated(self._auth_login_user, password == sim_auth[1]) + del self._auth_login_user + + def _auth_buggy(self, arg=None): + # This AUTH mechanism will 'trap' client in a neverending 334 + # base64 encoded 'BuGgYbUgGy' + self.push('334 QnVHZ1liVWdHeQ==') + + def _auth_cram_md5(self, arg=None): + if arg is None: + self.push('334 {}'.format(sim_cram_md5_challenge)) + else: + logpass = self._decode_base64(arg) + try: + user, hashed_pass = logpass.split() + except ValueError as e: + self.push('535 Splitting response {!r} into user and password ' + 'failed: {}'.format(logpass, e)) + return False + valid_hashed_pass = hmac.HMAC( + sim_auth[1].encode('ascii'), + self._decode_base64(sim_cram_md5_challenge).encode('ascii'), + 'md5').hexdigest() + self._authenticated(user, hashed_pass == valid_hashed_pass) + # end AUTH related stuff. + + def smtp_EHLO(self, arg): + resp = ('250-testhost\r\n' + '250-EXPN\r\n' + '250-SIZE 20000000\r\n' + '250-STARTTLS\r\n' + '250-DELIVERBY\r\n') + resp = resp + self._extrafeatures + '250 HELP' + self.push(resp) + self.seen_greeting = arg + self.extended_smtp = True + + def smtp_VRFY(self, arg): + # For max compatibility smtplib should be sending the raw address. 
+ if arg in sim_users: + self.push('250 %s %s' % (sim_users[arg], smtplib.quoteaddr(arg))) + else: + self.push('550 No such user: %s' % arg) + + def smtp_EXPN(self, arg): + list_name = arg.lower() + if list_name in sim_lists: + user_list = sim_lists[list_name] + for n, user_email in enumerate(user_list): + quoted_addr = smtplib.quoteaddr(user_email) + if n < len(user_list) - 1: + self.push('250-%s %s' % (sim_users[user_email], quoted_addr)) + else: + self.push('250 %s %s' % (sim_users[user_email], quoted_addr)) + else: + self.push('550 No access for you!') + + def smtp_QUIT(self, arg): + if self.quit_response is None: + super(SimSMTPChannel, self).smtp_QUIT(arg) + else: + self.push(self.quit_response) + self.close_when_done() + + def smtp_MAIL(self, arg): + if self.mail_response is None: + super().smtp_MAIL(arg) + else: + self.push(self.mail_response) + if self.disconnect: + self.close_when_done() + + def smtp_RCPT(self, arg): + if self.rcpt_response is None: + super().smtp_RCPT(arg) + return + self.rcpt_count += 1 + self.push(self.rcpt_response[self.rcpt_count-1]) + + def smtp_RSET(self, arg): + self.rset_count += 1 + super().smtp_RSET(arg) + + def smtp_DATA(self, arg): + if self.data_response is None: + super().smtp_DATA(arg) + else: + self.push(self.data_response) + + def handle_error(self): + raise + + +class SimSMTPServer(smtpd.SMTPServer): + + channel_class = SimSMTPChannel + + def __init__(self, *args, **kw): + self._extra_features = [] + self._addresses = {} + smtpd.SMTPServer.__init__(self, *args, **kw) + + def handle_accepted(self, conn, addr): + self._SMTPchannel = self.channel_class( + self._extra_features, self, conn, addr, + decode_data=self._decode_data) + + def process_message(self, peer, mailfrom, rcpttos, data): + self._addresses['from'] = mailfrom + self._addresses['tos'] = rcpttos + + def add_feature(self, feature): + self._extra_features.append(feature) + + def handle_error(self): + raise + + +# Test various SMTP & ESMTP commands/behaviors that 
require a simulated server +# (i.e., something with more features than DebuggingServer) +class SMTPSimTests(unittest.TestCase): + + def setUp(self): + self.thread_key = threading_helper.threading_setup() + self.real_getfqdn = socket.getfqdn + socket.getfqdn = mock_socket.getfqdn + self.serv_evt = threading.Event() + self.client_evt = threading.Event() + # Pick a random unused port by passing 0 for the port number + self.serv = SimSMTPServer((HOST, 0), ('nowhere', -1), decode_data=True) + # Keep a note of what port was assigned + self.port = self.serv.socket.getsockname()[1] + serv_args = (self.serv, self.serv_evt, self.client_evt) + self.thread = threading.Thread(target=debugging_server, args=serv_args) + self.thread.start() + + # wait until server thread has assigned a port number + self.serv_evt.wait() + self.serv_evt.clear() + + def tearDown(self): + socket.getfqdn = self.real_getfqdn + # indicate that the client is finished + self.client_evt.set() + # wait for the server thread to terminate + self.serv_evt.wait() + threading_helper.join_thread(self.thread) + del self.thread + self.doCleanups() + threading_helper.threading_cleanup(*self.thread_key) + + def testBasic(self): + # smoke test + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + smtp.quit() + + def testEHLO(self): + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + + # no features should be present before the EHLO + self.assertEqual(smtp.esmtp_features, {}) + + # features expected from the test server + expected_features = {'expn':'', + 'size': '20000000', + 'starttls': '', + 'deliverby': '', + 'help': '', + } + + smtp.ehlo() + self.assertEqual(smtp.esmtp_features, expected_features) + for k in expected_features: + self.assertTrue(smtp.has_extn(k)) + self.assertFalse(smtp.has_extn('unsupported-feature')) + smtp.quit() + + def testVRFY(self): + smtp = smtplib.SMTP(HOST, self.port, 
local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + + for addr_spec, name in sim_users.items(): + expected_known = (250, bytes('%s %s' % + (name, smtplib.quoteaddr(addr_spec)), + "ascii")) + self.assertEqual(smtp.vrfy(addr_spec), expected_known) + + u = 'nobody@nowhere.com' + expected_unknown = (550, ('No such user: %s' % u).encode('ascii')) + self.assertEqual(smtp.vrfy(u), expected_unknown) + smtp.quit() + + def testEXPN(self): + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + + for listname, members in sim_lists.items(): + users = [] + for m in members: + users.append('%s %s' % (sim_users[m], smtplib.quoteaddr(m))) + expected_known = (250, bytes('\n'.join(users), "ascii")) + self.assertEqual(smtp.expn(listname), expected_known) + + u = 'PSU-Members-List' + expected_unknown = (550, b'No access for you!') + self.assertEqual(smtp.expn(u), expected_unknown) + smtp.quit() + + def testAUTH_PLAIN(self): + self.serv.add_feature("AUTH PLAIN") + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + resp = smtp.login(sim_auth[0], sim_auth[1]) + self.assertEqual(resp, (235, b'Authentication Succeeded')) + smtp.close() + + def testAUTH_LOGIN(self): + self.serv.add_feature("AUTH LOGIN") + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + resp = smtp.login(sim_auth[0], sim_auth[1]) + self.assertEqual(resp, (235, b'Authentication Succeeded')) + smtp.close() + + def testAUTH_LOGIN_initial_response_ok(self): + self.serv.add_feature("AUTH LOGIN") + with smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) as smtp: + smtp.user, smtp.password = sim_auth + smtp.ehlo("test_auth_login") + resp = smtp.auth("LOGIN", smtp.auth_login, initial_response_ok=True) + self.assertEqual(resp, (235, b'Authentication Succeeded')) + + def testAUTH_LOGIN_initial_response_notok(self): + 
self.serv.add_feature("AUTH LOGIN") + with smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) as smtp: + smtp.user, smtp.password = sim_auth + smtp.ehlo("test_auth_login") + resp = smtp.auth("LOGIN", smtp.auth_login, initial_response_ok=False) + self.assertEqual(resp, (235, b'Authentication Succeeded')) + + def testAUTH_BUGGY(self): + self.serv.add_feature("AUTH BUGGY") + + def auth_buggy(challenge=None): + self.assertEqual(b"BuGgYbUgGy", challenge) + return "\0" + + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT + ) + try: + smtp.user, smtp.password = sim_auth + smtp.ehlo("test_auth_buggy") + expect = r"^Server AUTH mechanism infinite loop.*" + with self.assertRaisesRegex(smtplib.SMTPException, expect) as cm: + smtp.auth("BUGGY", auth_buggy, initial_response_ok=False) + finally: + smtp.close() + + # TODO: RUSTPYTHON + @unittest.expectedFailure + @hashlib_helper.requires_hashdigest('md5', openssl=True) + def testAUTH_CRAM_MD5(self): + self.serv.add_feature("AUTH CRAM-MD5") + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + resp = smtp.login(sim_auth[0], sim_auth[1]) + self.assertEqual(resp, (235, b'Authentication Succeeded')) + smtp.close() + + # TODO: RUSTPYTHON + @unittest.expectedFailure + @hashlib_helper.requires_hashdigest('md5', openssl=True) + def testAUTH_multiple(self): + # Test that multiple authentication methods are tried. 
+ self.serv.add_feature("AUTH BOGUS PLAIN LOGIN CRAM-MD5") + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + resp = smtp.login(sim_auth[0], sim_auth[1]) + self.assertEqual(resp, (235, b'Authentication Succeeded')) + smtp.close() + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_auth_function(self): + supported = {'PLAIN', 'LOGIN'} + try: + hashlib.md5() + except ValueError: + pass + else: + supported.add('CRAM-MD5') + for mechanism in supported: + self.serv.add_feature("AUTH {}".format(mechanism)) + for mechanism in supported: + with self.subTest(mechanism=mechanism): + smtp = smtplib.SMTP(HOST, self.port, + local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + smtp.ehlo('foo') + smtp.user, smtp.password = sim_auth[0], sim_auth[1] + method = 'auth_' + mechanism.lower().replace('-', '_') + resp = smtp.auth(mechanism, getattr(smtp, method)) + self.assertEqual(resp, (235, b'Authentication Succeeded')) + smtp.close() + + def test_quit_resets_greeting(self): + smtp = smtplib.SMTP(HOST, self.port, + local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + code, message = smtp.ehlo() + self.assertEqual(code, 250) + self.assertIn('size', smtp.esmtp_features) + smtp.quit() + self.assertNotIn('size', smtp.esmtp_features) + smtp.connect(HOST, self.port) + self.assertNotIn('size', smtp.esmtp_features) + smtp.ehlo_or_helo_if_needed() + self.assertIn('size', smtp.esmtp_features) + smtp.quit() + + def test_with_statement(self): + with smtplib.SMTP(HOST, self.port) as smtp: + code, message = smtp.noop() + self.assertEqual(code, 250) + self.assertRaises(smtplib.SMTPServerDisconnected, smtp.send, b'foo') + with smtplib.SMTP(HOST, self.port) as smtp: + smtp.close() + self.assertRaises(smtplib.SMTPServerDisconnected, smtp.send, b'foo') + + def test_with_statement_QUIT_failure(self): + with self.assertRaises(smtplib.SMTPResponseException) as error: + with smtplib.SMTP(HOST, self.port) as smtp: + 
smtp.noop() + self.serv._SMTPchannel.quit_response = '421 QUIT FAILED' + self.assertEqual(error.exception.smtp_code, 421) + self.assertEqual(error.exception.smtp_error, b'QUIT FAILED') + + #TODO: add tests for correct AUTH method fallback now that the + #test infrastructure can support it. + + # Issue 17498: make sure _rset does not raise SMTPServerDisconnected exception + def test__rest_from_mail_cmd(self): + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + smtp.noop() + self.serv._SMTPchannel.mail_response = '451 Requested action aborted' + self.serv._SMTPchannel.disconnect = True + with self.assertRaises(smtplib.SMTPSenderRefused): + smtp.sendmail('John', 'Sally', 'test message') + self.assertIsNone(smtp.sock) + + # Issue 5713: make sure close, not rset, is called if we get a 421 error + def test_421_from_mail_cmd(self): + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + smtp.noop() + self.serv._SMTPchannel.mail_response = '421 closing connection' + with self.assertRaises(smtplib.SMTPSenderRefused): + smtp.sendmail('John', 'Sally', 'test message') + self.assertIsNone(smtp.sock) + self.assertEqual(self.serv._SMTPchannel.rset_count, 0) + + def test_421_from_rcpt_cmd(self): + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + smtp.noop() + self.serv._SMTPchannel.rcpt_response = ['250 accepted', '421 closing'] + with self.assertRaises(smtplib.SMTPRecipientsRefused) as r: + smtp.sendmail('John', ['Sally', 'Frank', 'George'], 'test message') + self.assertIsNone(smtp.sock) + self.assertEqual(self.serv._SMTPchannel.rset_count, 0) + self.assertDictEqual(r.exception.args[0], {'Frank': (421, b'closing')}) + + def test_421_from_data_cmd(self): + class MySimSMTPChannel(SimSMTPChannel): + def found_terminator(self): + if self.smtp_state == self.DATA: + self.push('421 closing') + else: + super().found_terminator() + 
self.serv.channel_class = MySimSMTPChannel + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + smtp.noop() + with self.assertRaises(smtplib.SMTPDataError): + smtp.sendmail('John@foo.org', ['Sally@foo.org'], 'test message') + self.assertIsNone(smtp.sock) + self.assertEqual(self.serv._SMTPchannel.rcpt_count, 0) + + def test_smtputf8_NotSupportedError_if_no_server_support(self): + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.ehlo() + self.assertTrue(smtp.does_esmtp) + self.assertFalse(smtp.has_extn('smtputf8')) + self.assertRaises( + smtplib.SMTPNotSupportedError, + smtp.sendmail, + 'John', 'Sally', '', mail_options=['BODY=8BITMIME', 'SMTPUTF8']) + self.assertRaises( + smtplib.SMTPNotSupportedError, + smtp.mail, 'John', options=['BODY=8BITMIME', 'SMTPUTF8']) + + def test_send_unicode_without_SMTPUTF8(self): + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + self.assertRaises(UnicodeEncodeError, smtp.sendmail, 'Alice', 'Böb', '') + self.assertRaises(UnicodeEncodeError, smtp.mail, 'Älice') + + def test_send_message_error_on_non_ascii_addrs_if_no_smtputf8(self): + # This test is located here and not in the SMTPUTF8SimTests + # class because it needs a "regular" SMTP server to work + msg = EmailMessage() + msg['From'] = "Páolo " + msg['To'] = 'Dinsdale' + msg['Subject'] = 'Nudge nudge, wink, wink \u1F609' + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + with self.assertRaises(smtplib.SMTPNotSupportedError): + smtp.send_message(msg) + + def test_name_field_not_included_in_envelop_addresses(self): + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + + message = 
EmailMessage() + message['From'] = email.utils.formataddr(('Michaël', 'michael@example.com')) + message['To'] = email.utils.formataddr(('René', 'rene@example.com')) + + self.assertDictEqual(smtp.send_message(message), {}) + + self.assertEqual(self.serv._addresses['from'], 'michael@example.com') + self.assertEqual(self.serv._addresses['tos'], ['rene@example.com']) + + +class SimSMTPUTF8Server(SimSMTPServer): + + def __init__(self, *args, **kw): + # The base SMTP server turns these on automatically, but our test + # server is set up to munge the EHLO response, so we need to provide + # them as well. And yes, the call is to SMTPServer not SimSMTPServer. + self._extra_features = ['SMTPUTF8', '8BITMIME'] + smtpd.SMTPServer.__init__(self, *args, **kw) + + def handle_accepted(self, conn, addr): + self._SMTPchannel = self.channel_class( + self._extra_features, self, conn, addr, + decode_data=self._decode_data, + enable_SMTPUTF8=self.enable_SMTPUTF8, + ) + + def process_message(self, peer, mailfrom, rcpttos, data, mail_options=None, + rcpt_options=None): + self.last_peer = peer + self.last_mailfrom = mailfrom + self.last_rcpttos = rcpttos + self.last_message = data + self.last_mail_options = mail_options + self.last_rcpt_options = rcpt_options + + +class SMTPUTF8SimTests(unittest.TestCase): + + maxDiff = None + + def setUp(self): + self.thread_key = threading_helper.threading_setup() + self.real_getfqdn = socket.getfqdn + socket.getfqdn = mock_socket.getfqdn + self.serv_evt = threading.Event() + self.client_evt = threading.Event() + # Pick a random unused port by passing 0 for the port number + self.serv = SimSMTPUTF8Server((HOST, 0), ('nowhere', -1), + decode_data=False, + enable_SMTPUTF8=True) + # Keep a note of what port was assigned + self.port = self.serv.socket.getsockname()[1] + serv_args = (self.serv, self.serv_evt, self.client_evt) + self.thread = threading.Thread(target=debugging_server, args=serv_args) + self.thread.start() + + # wait until server thread has 
assigned a port number + self.serv_evt.wait() + self.serv_evt.clear() + + def tearDown(self): + socket.getfqdn = self.real_getfqdn + # indicate that the client is finished + self.client_evt.set() + # wait for the server thread to terminate + self.serv_evt.wait() + threading_helper.join_thread(self.thread) + del self.thread + self.doCleanups() + threading_helper.threading_cleanup(*self.thread_key) + + def test_test_server_supports_extensions(self): + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.ehlo() + self.assertTrue(smtp.does_esmtp) + self.assertTrue(smtp.has_extn('smtputf8')) + + def test_send_unicode_with_SMTPUTF8_via_sendmail(self): + m = '¡a test message containing unicode!'.encode('utf-8') + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.sendmail('Jőhn', 'Sálly', m, + mail_options=['BODY=8BITMIME', 'SMTPUTF8']) + self.assertEqual(self.serv.last_mailfrom, 'Jőhn') + self.assertEqual(self.serv.last_rcpttos, ['Sálly']) + self.assertEqual(self.serv.last_message, m) + self.assertIn('BODY=8BITMIME', self.serv.last_mail_options) + self.assertIn('SMTPUTF8', self.serv.last_mail_options) + self.assertEqual(self.serv.last_rcpt_options, []) + + def test_send_unicode_with_SMTPUTF8_via_low_level_API(self): + m = '¡a test message containing unicode!'.encode('utf-8') + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.ehlo() + self.assertEqual( + smtp.mail('Jő', options=['BODY=8BITMIME', 'SMTPUTF8']), + (250, b'OK')) + self.assertEqual(smtp.rcpt('János'), (250, b'OK')) + self.assertEqual(smtp.data(m), (250, b'OK')) + self.assertEqual(self.serv.last_mailfrom, 'Jő') + self.assertEqual(self.serv.last_rcpttos, ['János']) + self.assertEqual(self.serv.last_message, m) + self.assertIn('BODY=8BITMIME', 
self.serv.last_mail_options) + self.assertIn('SMTPUTF8', self.serv.last_mail_options) + self.assertEqual(self.serv.last_rcpt_options, []) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_send_message_uses_smtputf8_if_addrs_non_ascii(self): + msg = EmailMessage() + msg['From'] = "Páolo " + msg['To'] = 'Dinsdale' + msg['Subject'] = 'Nudge nudge, wink, wink \u1F609' + # XXX I don't know why I need two \n's here, but this is an existing + # bug (if it is one) and not a problem with the new functionality. + msg.set_content("oh là là, know what I mean, know what I mean?\n\n") + # XXX smtpd converts received /r/n to /n, so we can't easily test that + # we are successfully sending /r/n :(. + expected = textwrap.dedent("""\ + From: Páolo + To: Dinsdale + Subject: Nudge nudge, wink, wink \u1F609 + Content-Type: text/plain; charset="utf-8" + Content-Transfer-Encoding: 8bit + MIME-Version: 1.0 + + oh là là, know what I mean, know what I mean? + """) + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + self.assertEqual(smtp.send_message(msg), {}) + self.assertEqual(self.serv.last_mailfrom, 'főo@bar.com') + self.assertEqual(self.serv.last_rcpttos, ['Dinsdale']) + self.assertEqual(self.serv.last_message.decode(), expected) + self.assertIn('BODY=8BITMIME', self.serv.last_mail_options) + self.assertIn('SMTPUTF8', self.serv.last_mail_options) + self.assertEqual(self.serv.last_rcpt_options, []) + + +EXPECTED_RESPONSE = encode_base64(b'\0psu\0doesnotexist', eol='') + +class SimSMTPAUTHInitialResponseChannel(SimSMTPChannel): + def smtp_AUTH(self, arg): + # RFC 4954's AUTH command allows for an optional initial-response. + # Not all AUTH methods support this; some require a challenge. AUTH + # PLAIN does those, so test that here. See issue #15014. + args = arg.split() + if args[0].lower() == 'plain': + if len(args) == 2: + # AUTH PLAIN with the response base 64 + # encoded. 
Hard code the expected response for the test. + if args[1] == EXPECTED_RESPONSE: + self.push('235 Ok') + return + self.push('571 Bad authentication') + +class SimSMTPAUTHInitialResponseServer(SimSMTPServer): + channel_class = SimSMTPAUTHInitialResponseChannel + + +class SMTPAUTHInitialResponseSimTests(unittest.TestCase): + def setUp(self): + self.thread_key = threading_helper.threading_setup() + self.real_getfqdn = socket.getfqdn + socket.getfqdn = mock_socket.getfqdn + self.serv_evt = threading.Event() + self.client_evt = threading.Event() + # Pick a random unused port by passing 0 for the port number + self.serv = SimSMTPAUTHInitialResponseServer( + (HOST, 0), ('nowhere', -1), decode_data=True) + # Keep a note of what port was assigned + self.port = self.serv.socket.getsockname()[1] + serv_args = (self.serv, self.serv_evt, self.client_evt) + self.thread = threading.Thread(target=debugging_server, args=serv_args) + self.thread.start() + + # wait until server thread has assigned a port number + self.serv_evt.wait() + self.serv_evt.clear() + + def tearDown(self): + socket.getfqdn = self.real_getfqdn + # indicate that the client is finished + self.client_evt.set() + # wait for the server thread to terminate + self.serv_evt.wait() + threading_helper.join_thread(self.thread) + del self.thread + self.doCleanups() + threading_helper.threading_cleanup(*self.thread_key) + + def testAUTH_PLAIN_initial_response_login(self): + self.serv.add_feature('AUTH PLAIN') + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + smtp.login('psu', 'doesnotexist') + smtp.close() + + def testAUTH_PLAIN_initial_response_auth(self): + self.serv.add_feature('AUTH PLAIN') + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + smtp.user = 'psu' + smtp.password = 'doesnotexist' + code, response = smtp.auth('plain', smtp.auth_plain) + smtp.close() + self.assertEqual(code, 235) + + +if __name__ == 
'__main__': + unittest.main() diff --git a/Lib/test/test_smtpnet.py b/Lib/test/test_smtpnet.py new file mode 100644 index 0000000000..be25e961f7 --- /dev/null +++ b/Lib/test/test_smtpnet.py @@ -0,0 +1,90 @@ +import unittest +from test import support +from test.support import import_helper +from test.support import socket_helper +import smtplib +import socket + +ssl = import_helper.import_module("ssl") + +support.requires("network") + +def check_ssl_verifiy(host, port): + context = ssl.create_default_context() + with socket.create_connection((host, port)) as sock: + try: + sock = context.wrap_socket(sock, server_hostname=host) + except Exception: + return False + else: + sock.close() + return True + + +class SmtpTest(unittest.TestCase): + testServer = 'smtp.gmail.com' + remotePort = 587 + + def test_connect_starttls(self): + support.get_attribute(smtplib, 'SMTP_SSL') + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + with socket_helper.transient_internet(self.testServer): + server = smtplib.SMTP(self.testServer, self.remotePort) + try: + server.starttls(context=context) + except smtplib.SMTPException as e: + if e.args[0] == 'STARTTLS extension not supported by server.': + unittest.skip(e.args[0]) + else: + raise + server.ehlo() + server.quit() + + +class SmtpSSLTest(unittest.TestCase): + testServer = 'smtp.gmail.com' + remotePort = 465 + + def test_connect(self): + support.get_attribute(smtplib, 'SMTP_SSL') + with socket_helper.transient_internet(self.testServer): + server = smtplib.SMTP_SSL(self.testServer, self.remotePort) + server.ehlo() + server.quit() + + def test_connect_default_port(self): + support.get_attribute(smtplib, 'SMTP_SSL') + with socket_helper.transient_internet(self.testServer): + server = smtplib.SMTP_SSL(self.testServer) + server.ehlo() + server.quit() + + @support.requires_resource('walltime') + def test_connect_using_sslcontext(self): + context = 
ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + support.get_attribute(smtplib, 'SMTP_SSL') + with socket_helper.transient_internet(self.testServer): + server = smtplib.SMTP_SSL(self.testServer, self.remotePort, context=context) + server.ehlo() + server.quit() + + def test_connect_using_sslcontext_verified(self): + with socket_helper.transient_internet(self.testServer): + can_verify = check_ssl_verifiy(self.testServer, self.remotePort) + if not can_verify: + self.skipTest("SSL certificate can't be verified") + + support.get_attribute(smtplib, 'SMTP_SSL') + context = ssl.create_default_context() + with socket_helper.transient_internet(self.testServer): + server = smtplib.SMTP_SSL(self.testServer, self.remotePort, context=context) + server.ehlo() + server.quit() + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/Lib/test/test_tokenize.py b/Lib/test/test_tokenize.py index e2d2f89454..44ef4e2416 100644 --- a/Lib/test/test_tokenize.py +++ b/Lib/test/test_tokenize.py @@ -1237,7 +1237,6 @@ def test_utf8_normalization(self): found, consumed_lines = detect_encoding(rl) self.assertEqual(found, "utf-8") - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") def test_short_files(self): readline = self.get_readline((b'print(something)\n',)) encoding, consumed_lines = detect_encoding(readline) @@ -1316,7 +1315,6 @@ def readline(self): ins = Bunk(lines, path) detect_encoding(ins.readline) - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") def test_open_error(self): # Issue #23840: open() must close the binary file on error m = BytesIO(b'#coding:xxx') diff --git a/Lib/test/test_unary.py b/Lib/test/test_unary.py index c3c17cc9f6..a45fbf6bd6 100644 --- a/Lib/test/test_unary.py +++ b/Lib/test/test_unary.py @@ -8,7 +8,6 @@ def test_negative(self): self.assertTrue(-2 == 0 - 2) self.assertEqual(-0, 0) self.assertEqual(--2, 2) - self.assertTrue(-2 == 0 - 2) self.assertTrue(-2.0 
== 0 - 2.0) self.assertTrue(-2j == 0 - 2j) @@ -16,15 +15,13 @@ def test_positive(self): self.assertEqual(+2, 2) self.assertEqual(+0, 0) self.assertEqual(++2, 2) - self.assertEqual(+2, 2) self.assertEqual(+2.0, 2.0) self.assertEqual(+2j, 2j) def test_invert(self): - self.assertTrue(-2 == 0 - 2) - self.assertEqual(-0, 0) - self.assertEqual(--2, 2) - self.assertTrue(-2 == 0 - 2) + self.assertTrue(~2 == -(2+1)) + self.assertEqual(~0, -1) + self.assertEqual(~~2, 2) def test_no_overflow(self): nines = "9" * 32 diff --git a/Lib/test/test_zlib.py b/Lib/test/test_zlib.py index 686131be74..0a75457ad8 100644 --- a/Lib/test/test_zlib.py +++ b/Lib/test/test_zlib.py @@ -20,18 +20,18 @@ 'requires Decompress.copy()') -# def _zlib_runtime_version_tuple(zlib_version=zlib.ZLIB_RUNTIME_VERSION): -# # Register "1.2.3" as "1.2.3.0" -# # or "1.2.0-linux","1.2.0.f","1.2.0.f-linux" -# v = zlib_version.split('-', 1)[0].split('.') -# if len(v) < 4: -# v.append('0') -# elif not v[-1].isnumeric(): -# v[-1] = '0' -# return tuple(map(int, v)) -# -# -# ZLIB_RUNTIME_VERSION_TUPLE = _zlib_runtime_version_tuple() +def _zlib_runtime_version_tuple(zlib_version=zlib.ZLIB_RUNTIME_VERSION): + # Register "1.2.3" as "1.2.3.0" + # or "1.2.0-linux","1.2.0.f","1.2.0.f-linux" + v = zlib_version.split('-', 1)[0].split('.') + if len(v) < 4: + v.append('0') + elif not v[-1].isnumeric(): + v[-1] = '0' + return tuple(map(int, v)) + + +ZLIB_RUNTIME_VERSION_TUPLE = _zlib_runtime_version_tuple() # bpo-46623: When a hardware accelerator is used (currently only on s390x), @@ -66,8 +66,6 @@ class VersionTestCase(unittest.TestCase): - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_library_version(self): # Test that the major version of the actual library in use matches the # major version that we were compiled against. 
We can't guarantee that @@ -282,8 +280,6 @@ def test_64bit_compress(self, size): class CompressObjectTestCase(BaseCompressTestCase, unittest.TestCase): - # TODO: RUSTPYTHON - @unittest.expectedFailure # Test compression object def test_pair(self): # straightforward compress/decompress objects @@ -307,8 +303,6 @@ def test_pair(self): self.assertIsInstance(dco.unconsumed_tail, bytes) self.assertIsInstance(dco.unused_data, bytes) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_keywords(self): level = 2 method = zlib.DEFLATED @@ -466,8 +460,6 @@ def test_decompressmaxlen(self, flush=False): def test_decompressmaxlenflush(self): self.test_decompressmaxlen(flush=True) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_maxlenmisc(self): # Misc tests of max_length dco = zlib.decompressobj() @@ -498,7 +490,7 @@ def test_clear_unconsumed_tail(self): ddata += dco.decompress(dco.unconsumed_tail) self.assertEqual(dco.unconsumed_tail, b"") - # TODO: RUSTPYTHON + # TODO: RUSTPYTHON: Z_BLOCK support in flate2 @unittest.expectedFailure def test_flushes(self): # Test flush() with the various options, using all the @@ -560,8 +552,6 @@ def test_empty_flush(self): dco = zlib.decompressobj() self.assertEqual(dco.flush(), b"") # Returns nothing - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_dictionary(self): h = HAMLET_SCENE # Build a simulated dictionary out of the words in HAMLET. @@ -578,8 +568,6 @@ def test_dictionary(self): dco = zlib.decompressobj() self.assertRaises(zlib.error, dco.decompress, cd) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_dictionary_streaming(self): # This simulates the reuse of a compressor object for compressing # several separate data streams. 
@@ -652,8 +640,6 @@ def test_decompress_unused_data(self): self.assertEqual(dco.unconsumed_tail, b'') self.assertEqual(dco.unused_data, remainder) - # TODO: RUSTPYTHON - @unittest.expectedFailure # issue27164 def test_decompress_raw_with_dictionary(self): zdict = b'abcdefghijklmnopqrstuvwxyz' @@ -829,7 +815,7 @@ def test_large_unconsumed_tail(self, size): finally: comp = uncomp = data = None - # TODO: RUSTPYTHON + # TODO: RUSTPYTHON: wbits=0 support in flate2 @unittest.expectedFailure def test_wbits(self): # wbits=0 only supported since zlib v1.2.3.5 @@ -997,8 +983,6 @@ def testDecompressUnusedData(self): self.assertEqual(text, self.TEXT) self.assertEqual(zlibd.unused_data, unused_data) - # TODO: RUSTPYTHON - @unittest.expectedFailure def testEOFError(self): zlibd = zlib._ZlibDecompressor() text = zlibd.decompress(self.DATA) @@ -1029,8 +1013,6 @@ def testPickle(self): with self.assertRaises(TypeError): pickle.dumps(zlib._ZlibDecompressor(), proto) - # TODO: RUSTPYTHON - @unittest.expectedFailure def testDecompressorChunksMaxsize(self): zlibd = zlib._ZlibDecompressor() max_length = 100 @@ -1062,8 +1044,6 @@ def testDecompressorChunksMaxsize(self): self.assertEqual(out, self.BIG_TEXT) self.assertEqual(zlibd.unused_data, b"") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_decompressor_inputbuf_1(self): # Test reusing input buffer after moving existing # contents to beginning @@ -1086,8 +1066,6 @@ def test_decompressor_inputbuf_1(self): out.append(zlibd.decompress(self.DATA[105:])) self.assertEqual(b''.join(out), self.TEXT) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_decompressor_inputbuf_2(self): # Test reusing input buffer by appending data at the # end right away @@ -1109,8 +1087,6 @@ def test_decompressor_inputbuf_2(self): out.append(zlibd.decompress(self.DATA[300:])) self.assertEqual(b''.join(out), self.TEXT) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_decompressor_inputbuf_3(self): # Test reusing input buffer after extending 
it diff --git a/README.md b/README.md index 5b5b16d566..38e4d8fa8c 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # [RustPython](https://rustpython.github.io/) -A Python-3 (CPython >= 3.12.0) Interpreter written in Rust :snake: :scream: +A Python-3 (CPython >= 3.13.0) Interpreter written in Rust :snake: :scream: :metal:. [![Build Status](https://github.com/RustPython/RustPython/workflows/CI/badge.svg)](https://github.com/RustPython/RustPython/actions?query=workflow%3ACI) diff --git a/benches/execution.rs b/benches/execution.rs index 6ec9b89f2a..57529bbb3a 100644 --- a/benches/execution.rs +++ b/benches/execution.rs @@ -1,11 +1,11 @@ use criterion::measurement::WallTime; use criterion::{ - black_box, criterion_group, criterion_main, Bencher, BenchmarkGroup, BenchmarkId, Criterion, - Throughput, + Bencher, BenchmarkGroup, BenchmarkId, Criterion, Throughput, black_box, criterion_group, + criterion_main, }; use rustpython_compiler::Mode; -use rustpython_parser::ast; use rustpython_parser::Parse; +use rustpython_parser::ast; use rustpython_vm::{Interpreter, PyResult, Settings}; use std::collections::HashMap; use std::path::Path; diff --git a/benches/microbenchmarks.rs b/benches/microbenchmarks.rs index b742b959b7..6f41f00d6c 100644 --- a/benches/microbenchmarks.rs +++ b/benches/microbenchmarks.rs @@ -1,6 +1,6 @@ use criterion::{ - criterion_group, criterion_main, measurement::WallTime, BatchSize, BenchmarkGroup, BenchmarkId, - Criterion, Throughput, + BatchSize, BenchmarkGroup, BenchmarkId, Criterion, Throughput, criterion_group, criterion_main, + measurement::WallTime, }; use pyo3::types::PyAnyMethods; use rustpython_compiler::Mode; diff --git a/common/src/boxvec.rs b/common/src/boxvec.rs index 9b73fa0103..3b1f7e90a0 100644 --- a/common/src/boxvec.rs +++ b/common/src/boxvec.rs @@ -87,13 +87,16 @@ impl BoxVec { pub unsafe fn push_unchecked(&mut self, element: T) { let len = self.len(); debug_assert!(len < self.capacity()); - 
ptr::write(self.get_unchecked_ptr(len), element); - self.set_len(len + 1); + // SAFETY: len < capacity + unsafe { + ptr::write(self.get_unchecked_ptr(len), element); + self.set_len(len + 1); + } } /// Get pointer to where element at `index` would be unsafe fn get_unchecked_ptr(&mut self, index: usize) -> *mut T { - self.xs.as_mut_ptr().add(index).cast() + unsafe { self.xs.as_mut_ptr().add(index).cast() } } pub fn insert(&mut self, index: usize, element: T) { @@ -568,7 +571,7 @@ unsafe fn raw_ptr_add(ptr: *mut T, offset: usize) -> *mut T { // Special case for ZST (ptr as usize).wrapping_add(offset) as _ } else { - ptr.add(offset) + unsafe { ptr.add(offset) } } } @@ -576,7 +579,7 @@ unsafe fn raw_ptr_write(ptr: *mut T, value: T) { if mem::size_of::() == 0 { /* nothing */ } else { - ptr::write(ptr, value) + unsafe { ptr::write(ptr, value) } } } diff --git a/common/src/crt_fd.rs b/common/src/crt_fd.rs index 14f61b8059..64d4df98a5 100644 --- a/common/src/crt_fd.rs +++ b/common/src/crt_fd.rs @@ -6,7 +6,7 @@ use std::{cmp, ffi, io}; #[cfg(windows)] use libc::commit as fsync; #[cfg(windows)] -extern "C" { +unsafe extern "C" { #[link_name = "_chsize_s"] fn ftruncate(fd: i32, len: i64) -> i32; } @@ -74,7 +74,7 @@ impl Fd { #[cfg(windows)] pub fn to_raw_handle(&self) -> io::Result { - extern "C" { + unsafe extern "C" { fn _get_osfhandle(fd: i32) -> libc::intptr_t; } let handle = unsafe { suppress_iph!(_get_osfhandle(self.0)) }; diff --git a/common/src/encodings.rs b/common/src/encodings.rs index 4e0c1de56a..858d3b8c6b 100644 --- a/common/src/encodings.rs +++ b/common/src/encodings.rs @@ -42,8 +42,8 @@ struct DecodeError<'a> { /// # Safety /// `v[..valid_up_to]` must be valid utf8 unsafe fn make_decode_err(v: &[u8], valid_up_to: usize, err_len: Option) -> DecodeError<'_> { - let valid_prefix = core::str::from_utf8_unchecked(v.get_unchecked(..valid_up_to)); - let rest = v.get_unchecked(valid_up_to..); + let (valid_prefix, rest) = unsafe { v.split_at_unchecked(valid_up_to) }; + 
let valid_prefix = unsafe { core::str::from_utf8_unchecked(valid_prefix) }; DecodeError { valid_prefix, rest, diff --git a/common/src/fileutils.rs b/common/src/fileutils.rs index dcb78675d8..7d5ff01942 100644 --- a/common/src/fileutils.rs +++ b/common/src/fileutils.rs @@ -5,7 +5,7 @@ pub use libc::stat as StatStruct; #[cfg(windows)] -pub use windows::{fstat, StatStruct}; +pub use windows::{StatStruct, fstat}; #[cfg(not(windows))] pub fn fstat(fd: libc::c_int) -> std::io::Result { @@ -28,19 +28,19 @@ pub mod windows { use std::ffi::{CString, OsStr, OsString}; use std::os::windows::ffi::OsStrExt; use std::sync::OnceLock; - use windows_sys::core::PCWSTR; use windows_sys::Win32::Foundation::{ - FreeLibrary, SetLastError, BOOL, ERROR_INVALID_HANDLE, ERROR_NOT_SUPPORTED, FILETIME, - HANDLE, INVALID_HANDLE_VALUE, + BOOL, ERROR_INVALID_HANDLE, ERROR_NOT_SUPPORTED, FILETIME, FreeLibrary, HANDLE, + INVALID_HANDLE_VALUE, SetLastError, }; use windows_sys::Win32::Storage::FileSystem::{ - FileBasicInfo, FileIdInfo, GetFileInformationByHandle, GetFileInformationByHandleEx, - GetFileType, BY_HANDLE_FILE_INFORMATION, FILE_ATTRIBUTE_DIRECTORY, FILE_ATTRIBUTE_READONLY, + BY_HANDLE_FILE_INFORMATION, FILE_ATTRIBUTE_DIRECTORY, FILE_ATTRIBUTE_READONLY, FILE_ATTRIBUTE_REPARSE_POINT, FILE_BASIC_INFO, FILE_ID_INFO, FILE_TYPE_CHAR, - FILE_TYPE_DISK, FILE_TYPE_PIPE, FILE_TYPE_UNKNOWN, + FILE_TYPE_DISK, FILE_TYPE_PIPE, FILE_TYPE_UNKNOWN, FileBasicInfo, FileIdInfo, + GetFileInformationByHandle, GetFileInformationByHandleEx, GetFileType, }; use windows_sys::Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryW}; use windows_sys::Win32::System::SystemServices::IO_REPARSE_TAG_SYMLINK; + use windows_sys::core::PCWSTR; pub const S_IFIFO: libc::c_int = 0o010000; pub const S_IFLNK: libc::c_int = 0o120000; @@ -94,7 +94,7 @@ pub mod windows { } } - extern "C" { + unsafe extern "C" { fn _get_osfhandle(fd: i32) -> libc::intptr_t; } diff --git a/common/src/float_ops.rs b/common/src/float_ops.rs index 
69ae8833a2..46e2d57067 100644 --- a/common/src/float_ops.rs +++ b/common/src/float_ops.rs @@ -64,11 +64,7 @@ pub fn gt_int(value: f64, other_int: &BigInt) -> bool { } pub fn div(v1: f64, v2: f64) -> Option { - if v2 != 0.0 { - Some(v1 / v2) - } else { - None - } + if v2 != 0.0 { Some(v1 / v2) } else { None } } pub fn mod_(v1: f64, v2: f64) -> Option { @@ -125,10 +121,69 @@ pub fn nextafter(x: f64, y: f64) -> f64 { let b = x.to_bits(); let bits = if (y > x) == (x > 0.0) { b + 1 } else { b - 1 }; let ret = f64::from_bits(bits); - if ret == 0.0 { - ret.copysign(x) + if ret == 0.0 { ret.copysign(x) } else { ret } + } +} + +#[allow(clippy::float_cmp)] +pub fn nextafter_with_steps(x: f64, y: f64, steps: u64) -> f64 { + if x == y { + y + } else if x.is_nan() || y.is_nan() { + f64::NAN + } else if x >= f64::INFINITY { + f64::MAX + } else if x <= f64::NEG_INFINITY { + f64::MIN + } else if x == 0.0 { + f64::from_bits(1).copysign(y) + } else { + if steps == 0 { + return x; + } + + if x.is_nan() { + return x; + } + + if y.is_nan() { + return y; + } + + let sign_bit: u64 = 1 << 63; + + let mut ux = x.to_bits(); + let uy = y.to_bits(); + + let ax = ux & !sign_bit; + let ay = uy & !sign_bit; + + // If signs are different + if ((ux ^ uy) & sign_bit) != 0 { + return if ax + ay <= steps { + f64::from_bits(uy) + } else if ax < steps { + let result = (uy & sign_bit) | (steps - ax); + f64::from_bits(result) + } else { + ux -= steps; + f64::from_bits(ux) + }; + } + + // If signs are the same + if ax > ay { + if ax - ay >= steps { + ux -= steps; + f64::from_bits(ux) + } else { + f64::from_bits(uy) + } + } else if ay - ax >= steps { + ux += steps; + f64::from_bits(ux) } else { - ret + f64::from_bits(uy) } } } diff --git a/common/src/hash.rs b/common/src/hash.rs index 4e87eff799..bbc30b1fe1 100644 --- a/common/src/hash.rs +++ b/common/src/hash.rs @@ -37,11 +37,11 @@ impl BuildHasher for HashSecret { } } -impl rand::distributions::Distribution for rand::distributions::Standard { +impl 
rand::distr::Distribution for rand::distr::StandardUniform { fn sample(&self, rng: &mut R) -> HashSecret { HashSecret { - k0: rng.gen(), - k1: rng.gen(), + k0: rng.random(), + k1: rng.random(), } } } @@ -114,7 +114,7 @@ pub fn hash_float(value: f64) -> Option { let mut e = frexp.1; let mut x: PyUHash = 0; while m != 0.0 { - x = ((x << 28) & MODULUS) | x >> (BITS - 28); + x = ((x << 28) & MODULUS) | (x >> (BITS - 28)); m *= 268_435_456.0; // 2**28 e -= 28; let y = m as PyUHash; // pull out integer part @@ -132,7 +132,7 @@ pub fn hash_float(value: f64) -> Option { } else { BITS32 - 1 - ((-1 - e) % BITS32) }; - x = ((x << e) & MODULUS) | x >> (BITS32 - e); + x = ((x << e) & MODULUS) | (x >> (BITS32 - e)); Some(fix_sentinel(x as PyHash * value.signum() as PyHash)) } @@ -150,11 +150,7 @@ pub fn hash_bigint(value: &BigInt) -> PyHash { #[inline(always)] pub fn fix_sentinel(x: PyHash) -> PyHash { - if x == SENTINEL { - -2 - } else { - x - } + if x == SENTINEL { -2 } else { x } } #[inline] diff --git a/common/src/int.rs b/common/src/int.rs index ca449ac708..00b5231dff 100644 --- a/common/src/int.rs +++ b/common/src/int.rs @@ -50,7 +50,7 @@ pub fn bytes_to_int(lit: &[u8], mut base: u32) -> Option { base = parsed; true } else { - if let [_first, ref others @ .., last] = lit { + if let [_first, others @ .., last] = lit { let is_zero = others.iter().all(|&c| c == b'0' || c == b'_') && *last == b'0'; if !is_zero { diff --git a/common/src/linked_list.rs b/common/src/linked_list.rs index 3040bab0b9..7f55d727fb 100644 --- a/common/src/linked_list.rs +++ b/common/src/linked_list.rs @@ -208,37 +208,39 @@ impl LinkedList { /// The caller **must** ensure that `node` is currently contained by /// `self` or not contained by any other list. 
pub unsafe fn remove(&mut self, node: NonNull) -> Option { - if let Some(prev) = L::pointers(node).as_ref().get_prev() { - debug_assert_eq!(L::pointers(prev).as_ref().get_next(), Some(node)); - L::pointers(prev) - .as_mut() - .set_next(L::pointers(node).as_ref().get_next()); - } else { - if self.head != Some(node) { - return None; + unsafe { + if let Some(prev) = L::pointers(node).as_ref().get_prev() { + debug_assert_eq!(L::pointers(prev).as_ref().get_next(), Some(node)); + L::pointers(prev) + .as_mut() + .set_next(L::pointers(node).as_ref().get_next()); + } else { + if self.head != Some(node) { + return None; + } + + self.head = L::pointers(node).as_ref().get_next(); } - self.head = L::pointers(node).as_ref().get_next(); - } + if let Some(next) = L::pointers(node).as_ref().get_next() { + debug_assert_eq!(L::pointers(next).as_ref().get_prev(), Some(node)); + L::pointers(next) + .as_mut() + .set_prev(L::pointers(node).as_ref().get_prev()); + } else { + // // This might be the last item in the list + // if self.tail != Some(node) { + // return None; + // } + + // self.tail = L::pointers(node).as_ref().get_prev(); + } - if let Some(next) = L::pointers(node).as_ref().get_next() { - debug_assert_eq!(L::pointers(next).as_ref().get_prev(), Some(node)); - L::pointers(next) - .as_mut() - .set_prev(L::pointers(node).as_ref().get_prev()); - } else { - // // This might be the last item in the list - // if self.tail != Some(node) { - // return None; - // } + L::pointers(node).as_mut().set_next(None); + L::pointers(node).as_mut().set_prev(None); - // self.tail = L::pointers(node).as_ref().get_prev(); + Some(L::from_raw(node)) } - - L::pointers(node).as_mut().set_next(None); - L::pointers(node).as_mut().set_prev(None); - - Some(L::from_raw(node)) } // pub fn last(&self) -> Option<&L::Target> { diff --git a/common/src/lock/cell_lock.rs b/common/src/lock/cell_lock.rs index b10101f269..1edd622a20 100644 --- a/common/src/lock/cell_lock.rs +++ b/common/src/lock/cell_lock.rs @@ -140,12 
+140,12 @@ unsafe impl RawRwLockUpgrade for RawCellRwLock { #[inline] unsafe fn unlock_upgradable(&self) { - self.unlock_shared() + unsafe { self.unlock_shared() } } #[inline] unsafe fn upgrade(&self) { - if !self.try_upgrade() { + if !unsafe { self.try_upgrade() } { deadlock("upgrade ", "RwLock") } } diff --git a/common/src/lock/thread_mutex.rs b/common/src/lock/thread_mutex.rs index ba36898780..35b0b9ac6d 100644 --- a/common/src/lock/thread_mutex.rs +++ b/common/src/lock/thread_mutex.rs @@ -65,7 +65,7 @@ impl RawThreadMutex { /// This method may only be called if the mutex is held by the current thread. pub unsafe fn unlock(&self) { self.owner.store(0, Ordering::Relaxed); - self.mutex.unlock(); + unsafe { self.mutex.unlock() }; } } diff --git a/common/src/macros.rs b/common/src/macros.rs index 318ab06986..08d00e592d 100644 --- a/common/src/macros.rs +++ b/common/src/macros.rs @@ -41,7 +41,7 @@ pub mod __macro_private { libc::uintptr_t, ); #[cfg(target_env = "msvc")] - extern "C" { + unsafe extern "C" { pub fn _set_thread_local_invalid_parameter_handler( pNew: InvalidParamHandler, ) -> InvalidParamHandler; diff --git a/common/src/os.rs b/common/src/os.rs index c16ec014c5..8a832270bc 100644 --- a/common/src/os.rs +++ b/common/src/os.rs @@ -23,7 +23,7 @@ pub fn last_os_error() -> io::Error { let err = io::Error::last_os_error(); // FIXME: probably not ideal, we need a bigger dichotomy between GetLastError and errno if err.raw_os_error() == Some(0) { - extern "C" { + unsafe extern "C" { fn _get_errno(pValue: *mut i32) -> i32; } let mut errno = 0; @@ -44,7 +44,7 @@ pub fn last_os_error() -> io::Error { pub fn last_posix_errno() -> i32 { let err = io::Error::last_os_error(); if err.raw_os_error() == Some(0) { - extern "C" { + unsafe extern "C" { fn _get_errno(pValue: *mut i32) -> i32; } let mut errno = 0; diff --git a/common/src/str.rs b/common/src/str.rs index 3f9bf583b8..b4f7a1a636 100644 --- a/common/src/str.rs +++ b/common/src/str.rs @@ -241,11 +241,7 @@ pub mod 
levenshtein { if b.is_ascii_uppercase() { b += b'a' - b'A'; } - if a == b { - CASE_COST - } else { - MOVE_COST - } + if a == b { CASE_COST } else { MOVE_COST } } pub fn levenshtein_distance(a: &str, b: &str, max_cost: usize) -> usize { @@ -322,6 +318,37 @@ pub mod levenshtein { } } +/// Replace all tabs in a string with spaces, using the given tab size. +pub fn expandtabs(input: &str, tab_size: usize) -> String { + let tab_stop = tab_size; + let mut expanded_str = String::with_capacity(input.len()); + let mut tab_size = tab_stop; + let mut col_count = 0usize; + for ch in input.chars() { + match ch { + '\t' => { + let num_spaces = tab_size - col_count; + col_count += num_spaces; + let expand = " ".repeat(num_spaces); + expanded_str.push_str(&expand); + } + '\r' | '\n' => { + expanded_str.push(ch); + col_count = 0; + tab_size = 0; + } + _ => { + expanded_str.push(ch); + col_count += 1; + } + } + if col_count >= tab_size { + tab_size += tab_stop; + } + } + expanded_str +} + /// Creates an [`AsciiStr`][ascii::AsciiStr] from a string literal, throwing a compile error if the /// literal isn't actually ascii. 
/// diff --git a/compiler/codegen/Cargo.toml b/compiler/codegen/Cargo.toml index 0fe950be71..0817a95894 100644 --- a/compiler/codegen/Cargo.toml +++ b/compiler/codegen/Cargo.toml @@ -11,6 +11,7 @@ license.workspace = true [dependencies] rustpython-ast = { workspace = true, features=["unparse", "constant-optimization"] } +rustpython-common = { workspace = true } rustpython-parser-core = { workspace = true } rustpython-compiler-core = { workspace = true } diff --git a/compiler/codegen/src/compile.rs b/compiler/codegen/src/compile.rs index eeb4b7dec6..6b1d15c720 100644 --- a/compiler/codegen/src/compile.rs +++ b/compiler/codegen/src/compile.rs @@ -8,18 +8,21 @@ #![deny(clippy::cast_possible_truncation)] use crate::{ + IndexSet, error::{CodegenError, CodegenErrorType}, ir, symboltable::{self, SymbolFlags, SymbolScope, SymbolTable}, - IndexSet, }; use itertools::Itertools; use num_complex::Complex64; use num_traits::ToPrimitive; use rustpython_ast::located::{self as located_ast, Located}; use rustpython_compiler_core::{ - bytecode::{self, Arg as OpArgMarker, CodeObject, ConstantData, Instruction, OpArg, OpArgType}, Mode, + bytecode::{ + self, Arg as OpArgMarker, CodeObject, ComparisonOperator, ConstantData, Instruction, OpArg, + OpArgType, + }, }; use rustpython_parser_core::source_code::{LineNumber, SourceLocation}; use std::borrow::Cow; @@ -211,6 +214,12 @@ macro_rules! 
emit { }; } +struct PatternContext { + current_block: usize, + blocks: Vec, + allow_irrefutable: bool, +} + impl Compiler { fn new(opts: CompileOpts, source_path: String, code_name: String) -> Self { let module_code = ir::CodeInfo { @@ -966,7 +975,7 @@ impl Compiler { } } located_ast::Expr::BinOp(_) | located_ast::Expr::UnaryOp(_) => { - return Err(self.error(CodegenErrorType::Delete("expression"))) + return Err(self.error(CodegenErrorType::Delete("expression"))); } _ => return Err(self.error(CodegenErrorType::Delete(expression.python_name()))), } @@ -1086,8 +1095,27 @@ impl Compiler { self.store_name(name.as_ref())?; } } - located_ast::TypeParam::ParamSpec(_) => todo!(), - located_ast::TypeParam::TypeVarTuple(_) => todo!(), + located_ast::TypeParam::ParamSpec(located_ast::TypeParamParamSpec { + name, .. + }) => { + self.emit_load_const(ConstantData::Str { + value: name.to_string(), + }); + emit!(self, Instruction::ParamSpec); + emit!(self, Instruction::Duplicate); + self.store_name(name.as_ref())?; + } + located_ast::TypeParam::TypeVarTuple(located_ast::TypeParamTypeVarTuple { + name, + .. + }) => { + self.emit_load_const(ConstantData::Str { + value: name.to_string(), + }); + emit!(self, Instruction::TypeVarTuple); + emit!(self, Instruction::Duplicate); + self.store_name(name.as_ref())?; + } }; } emit!( @@ -1185,7 +1213,7 @@ impl Compiler { if !finalbody.is_empty() { emit!(self, Instruction::PopBlock); // pop excepthandler block - // We enter the finally block, without exception. + // We enter the finally block, without exception. 
emit!(self, Instruction::EnterFinally); } @@ -1755,14 +1783,152 @@ impl Compiler { Ok(()) } + fn compile_pattern_value( + &mut self, + value: &located_ast::PatternMatchValue, + _pattern_context: &mut PatternContext, + ) -> CompileResult<()> { + self.compile_expression(&value.value)?; + emit!( + self, + Instruction::CompareOperation { + op: ComparisonOperator::Equal + } + ); + Ok(()) + } + + fn compile_pattern_as( + &mut self, + as_pattern: &located_ast::PatternMatchAs, + pattern_context: &mut PatternContext, + ) -> CompileResult<()> { + if as_pattern.pattern.is_none() && !pattern_context.allow_irrefutable { + // TODO: better error message + if let Some(_name) = as_pattern.name.as_ref() { + return Err( + self.error_loc(CodegenErrorType::InvalidMatchCase, as_pattern.location()) + ); + } + return Err(self.error_loc(CodegenErrorType::InvalidMatchCase, as_pattern.location())); + } + // Need to make a copy for (possibly) storing later: + emit!(self, Instruction::Duplicate); + if let Some(pattern) = &as_pattern.pattern { + self.compile_pattern_inner(pattern, pattern_context)?; + } + if let Some(name) = as_pattern.name.as_ref() { + self.store_name(name.as_str())?; + } else { + emit!(self, Instruction::Pop); + } + Ok(()) + } + + fn compile_pattern_inner( + &mut self, + pattern_type: &located_ast::Pattern, + pattern_context: &mut PatternContext, + ) -> CompileResult<()> { + match &pattern_type { + located_ast::Pattern::MatchValue(value) => { + self.compile_pattern_value(value, pattern_context) + } + located_ast::Pattern::MatchAs(as_pattern) => { + self.compile_pattern_as(as_pattern, pattern_context) + } + _ => { + eprintln!("not implemented pattern type: {pattern_type:?}"); + Err(self.error(CodegenErrorType::NotImplementedYet)) + } + } + } + + fn compile_pattern( + &mut self, + pattern_type: &located_ast::Pattern, + pattern_context: &mut PatternContext, + ) -> CompileResult<()> { + self.compile_pattern_inner(pattern_type, pattern_context)?; + emit!( + self, + 
Instruction::JumpIfFalse { + target: pattern_context.blocks[pattern_context.current_block + 1] + } + ); + Ok(()) + } + + fn compile_match_inner( + &mut self, + subject: &located_ast::Expr, + cases: &[located_ast::MatchCase], + pattern_context: &mut PatternContext, + ) -> CompileResult<()> { + self.compile_expression(subject)?; + pattern_context.blocks = std::iter::repeat_with(|| self.new_block()) + .take(cases.len() + 1) + .collect::>(); + let end_block = *pattern_context.blocks.last().unwrap(); + + let _match_case_type = cases.last().expect("cases is not empty"); + // TODO: get proper check for default case + // let has_default = match_case_type.pattern.is_match_as() && 1 < cases.len(); + let has_default = false; + for i in 0..cases.len() - (has_default as usize) { + self.switch_to_block(pattern_context.blocks[i]); + pattern_context.current_block = i; + pattern_context.allow_irrefutable = cases[i].guard.is_some() || i == cases.len() - 1; + let m = &cases[i]; + // Only copy the subject if we're *not* on the last case: + if i != cases.len() - has_default as usize - 1 { + emit!(self, Instruction::Duplicate); + } + self.compile_pattern(&m.pattern, pattern_context)?; + self.compile_statements(&m.body)?; + emit!(self, Instruction::Jump { target: end_block }); + } + // TODO: below code is not called and does not work + if has_default { + // A trailing "case _" is common, and lets us save a bit of redundant + // pushing and popping in the loop above: + let m = &cases.last().unwrap(); + self.switch_to_block(*pattern_context.blocks.last().unwrap()); + if cases.len() == 1 { + // No matches. 
Done with the subject: + emit!(self, Instruction::Pop); + } else { + // Show line coverage for default case (it doesn't create bytecode) + // emit!(self, Instruction::Nop); + } + self.compile_statements(&m.body)?; + } + + self.switch_to_block(end_block); + + let code = self.current_code_info(); + pattern_context + .blocks + .iter() + .zip(pattern_context.blocks.iter().skip(1)) + .for_each(|(a, b)| { + code.blocks[a.0 as usize].next = *b; + }); + Ok(()) + } + fn compile_match( &mut self, subject: &located_ast::Expr, cases: &[located_ast::MatchCase], ) -> CompileResult<()> { - eprintln!("match subject: {subject:?}"); - eprintln!("match cases: {cases:?}"); - Err(self.error(CodegenErrorType::NotImplementedYet)) + let mut pattern_context = PatternContext { + current_block: usize::MAX, + blocks: Vec::new(), + allow_irrefutable: false, + }; + self.compile_match_inner(subject, cases, &mut pattern_context)?; + Ok(()) } fn compile_chained_comparison( @@ -2557,7 +2723,7 @@ impl Compiler { fn compile_keywords(&mut self, keywords: &[located_ast::Keyword]) -> CompileResult<()> { let mut size = 0; - let groupby = keywords.iter().group_by(|e| e.arg.is_none()); + let groupby = keywords.iter().chunk_by(|e| e.arg.is_none()); for (is_unpacking, sub_keywords) in &groupby { if is_unpacking { for keyword in sub_keywords { @@ -2720,7 +2886,7 @@ impl Compiler { (false, element) } }) - .group_by(|(starred, _)| *starred); + .chunk_by(|(starred, _)| *starred); for (starred, run) in &groups { let mut run_size = 0; @@ -2958,7 +3124,9 @@ impl Compiler { | "with_statement" | "print_function" | "unicode_literals" | "generator_stop" => {} "annotations" => self.future_annotations = true, other => { - return Err(self.error(CodegenErrorType::InvalidFutureFeature(other.to_owned()))) + return Err( + self.error(CodegenErrorType::InvalidFutureFeature(other.to_owned())) + ); } } } @@ -3036,16 +3204,17 @@ impl Compiler { fn switch_to_block(&mut self, block: ir::BlockIdx) { let code = 
self.current_code_info(); let prev = code.current_block; + assert_ne!(prev, block, "recursive switching {prev:?} -> {block:?}"); assert_eq!( code.blocks[block].next, ir::BlockIdx::NULL, - "switching to completed block" + "switching {prev:?} -> {block:?} to completed block" ); let prev_block = &mut code.blocks[prev.0 as usize]; assert_eq!( prev_block.next.0, u32::MAX, - "switching from block that's already got a next" + "switching {prev:?} -> {block:?} from block that's already got a next" ); prev_block.next = block; code.current_block = block; @@ -3135,13 +3304,13 @@ impl Compiler { elt, generators, .. }) => { Self::contains_await(elt) - || generators.iter().any(|gen| Self::contains_await(&gen.iter)) + || generators.iter().any(|jen| Self::contains_await(&jen.iter)) } Expr::SetComp(located_ast::ExprSetComp { elt, generators, .. }) => { Self::contains_await(elt) - || generators.iter().any(|gen| Self::contains_await(&gen.iter)) + || generators.iter().any(|jen| Self::contains_await(&jen.iter)) } Expr::DictComp(located_ast::ExprDictComp { key, @@ -3151,13 +3320,13 @@ impl Compiler { }) => { Self::contains_await(key) || Self::contains_await(value) - || generators.iter().any(|gen| Self::contains_await(&gen.iter)) + || generators.iter().any(|jen| Self::contains_await(&jen.iter)) } Expr::GeneratorExp(located_ast::ExprGeneratorExp { elt, generators, .. }) => { Self::contains_await(elt) - || generators.iter().any(|gen| Self::contains_await(&gen.iter)) + || generators.iter().any(|jen| Self::contains_await(&jen.iter)) } Expr::Starred(expr) => Self::contains_await(&expr.value), Expr::IfExp(located_ast::ExprIfExp { @@ -3201,17 +3370,51 @@ impl EmitArg for ir::BlockIdx { } } +/// Strips leading whitespace from a docstring. +/// +/// The code has been ported from `_PyCompile_CleanDoc` in cpython. +/// `inspect.cleandoc` is also a good reference, but has a few incompatibilities. 
+fn clean_doc(doc: &str) -> String { + let doc = rustpython_common::str::expandtabs(doc, 8); + // First pass: find minimum indentation of any non-blank lines + // after first line. + let margin = doc + .lines() + // Find the non-blank lines + .filter(|line| !line.trim().is_empty()) + // get the one with the least indentation + .map(|line| line.chars().take_while(|c| c == &' ').count()) + .min(); + if let Some(margin) = margin { + let mut cleaned = String::with_capacity(doc.len()); + // copy first line without leading whitespace + if let Some(first_line) = doc.lines().next() { + cleaned.push_str(first_line.trim_start()); + } + // copy subsequent lines without margin. + for line in doc.split('\n').skip(1) { + cleaned.push('\n'); + let cleaned_line = line.chars().skip(margin).collect::(); + cleaned.push_str(&cleaned_line); + } + + cleaned + } else { + doc.to_owned() + } +} + fn split_doc<'a>( body: &'a [located_ast::Stmt], opts: &CompileOpts, ) -> (Option, &'a [located_ast::Stmt]) { if let Some((located_ast::Stmt::Expr(expr), body_rest)) = body.split_first() { if let Some(doc) = try_get_constant_string(std::slice::from_ref(&expr.value)) { - if opts.optimize < 2 { - return (Some(doc), body_rest); + return if opts.optimize < 2 { + (Some(clean_doc(&doc)), body_rest) } else { - return (None, body_rest); - } + (None, body_rest) + }; } } (None, body) @@ -3276,8 +3479,8 @@ impl ToU32 for usize { #[cfg(test)] mod tests { use super::*; - use rustpython_parser::ast::Suite; use rustpython_parser::Parse; + use rustpython_parser::ast::Suite; use rustpython_parser_core::source_code::LinearLocator; fn compile_exec(source: &str) -> CodeObject { diff --git a/compiler/codegen/src/error.rs b/compiler/codegen/src/error.rs index 017f735105..27333992df 100644 --- a/compiler/codegen/src/error.rs +++ b/compiler/codegen/src/error.rs @@ -30,6 +30,8 @@ pub enum CodegenErrorType { TooManyStarUnpack, EmptyWithItems, EmptyWithBody, + DuplicateStore(String), + InvalidMatchCase, NotImplementedYet, 
// RustPython marker for unimplemented features } @@ -75,6 +77,12 @@ impl fmt::Display for CodegenErrorType { EmptyWithBody => { write!(f, "empty body on With") } + DuplicateStore(s) => { + write!(f, "duplicate store {s}") + } + InvalidMatchCase => { + write!(f, "invalid match case") + } NotImplementedYet => { write!(f, "RustPython does not implement this feature yet") } diff --git a/compiler/codegen/src/ir.rs b/compiler/codegen/src/ir.rs index 9f1a86e51d..08e68d283a 100644 --- a/compiler/codegen/src/ir.rs +++ b/compiler/codegen/src/ir.rs @@ -199,11 +199,7 @@ impl CodeInfo { }) .collect::>(); - if found_cellarg { - Some(cell2arg) - } else { - None - } + if found_cellarg { Some(cell2arg) } else { None } } fn dce(&mut self) { diff --git a/compiler/codegen/src/symboltable.rs b/compiler/codegen/src/symboltable.rs index 8522c82037..7db5c4edbe 100644 --- a/compiler/codegen/src/symboltable.rs +++ b/compiler/codegen/src/symboltable.rs @@ -8,8 +8,8 @@ Inspirational file: https://github.com/python/cpython/blob/main/Python/symtable. */ use crate::{ - error::{CodegenError, CodegenErrorType}, IndexMap, + error::{CodegenError, CodegenErrorType}, }; use bitflags::bitflags; use rustpython_ast::{self as ast, located::Located}; @@ -505,7 +505,10 @@ impl SymbolTableAnalyzer { // check if assignee is an iterator in top scope if parent_symbol.flags.contains(SymbolFlags::ITER) { return Err(SymbolTableError { - error: format!("assignment expression cannot rebind comprehension iteration variable {}", symbol.name), + error: format!( + "assignment expression cannot rebind comprehension iteration variable {}", + symbol.name + ), location: None, }); } @@ -886,11 +889,13 @@ impl SymbolTableBuilder { self.scan_statements(orelse)?; self.scan_statements(finalbody)?; } - Stmt::Match(StmtMatch { subject, .. }) => { - return Err(SymbolTableError { - error: "match expression is not implemented yet".to_owned(), - location: Some(subject.location()), - }); + Stmt::Match(StmtMatch { subject, cases, .. 
}) => { + self.scan_expression(subject, ExpressionContext::Load)?; + for case in cases { + // TODO: below + // self.scan_pattern(&case.pattern, ExpressionContext::Load)?; + self.scan_statements(&case.body)?; + } } Stmt::Raise(StmtRaise { exc, cause, .. }) => { if let Some(expression) = exc { @@ -1255,8 +1260,18 @@ impl SymbolTableBuilder { self.scan_expression(binding, ExpressionContext::Load)?; } } - ast::located::TypeParam::ParamSpec(_) => todo!(), - ast::located::TypeParam::TypeVarTuple(_) => todo!(), + ast::located::TypeParam::ParamSpec(ast::TypeParamParamSpec { + name, + range: param_spec_range, + }) => { + self.register_name(name, SymbolUsage::Assigned, param_spec_range.start)?; + } + ast::located::TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { + name, + range: type_var_tuple_range, + }) => { + self.register_name(name, SymbolUsage::Assigned, type_var_tuple_range.start)?; + } } } Ok(()) @@ -1396,7 +1411,7 @@ impl SymbolTableBuilder { return Err(SymbolTableError { error: format!("cannot define nonlocal '{name}' at top level."), location, - }) + }); } _ => { // Ok! 
diff --git a/compiler/core/Cargo.toml b/compiler/core/Cargo.toml index 3d05fc2734..619ffcf61e 100644 --- a/compiler/core/Cargo.toml +++ b/compiler/core/Cargo.toml @@ -15,7 +15,7 @@ bitflags = { workspace = true } itertools = { workspace = true } malachite-bigint = { workspace = true } num-complex = { workspace = true } -serde = { version = "1.0.133", optional = true, default-features = false, features = ["derive"] } +serde = { version = "1.0.217", optional = true, default-features = false, features = ["derive"] } lz4_flex = "0.11" diff --git a/compiler/core/src/bytecode.rs b/compiler/core/src/bytecode.rs index c8dbc63744..11e49a47db 100644 --- a/compiler/core/src/bytecode.rs +++ b/compiler/core/src/bytecode.rs @@ -197,7 +197,7 @@ impl OpArgState { } #[inline(always)] pub fn extend(&mut self, arg: OpArgByte) -> OpArg { - self.state = self.state << 8 | u32::from(arg.0); + self.state = (self.state << 8) | u32::from(arg.0); OpArg(self.state) } #[inline(always)] @@ -293,10 +293,8 @@ impl Arg { /// # Safety /// T::from_op_arg(self) must succeed pub unsafe fn get_unchecked(self, arg: OpArg) -> T { - match T::from_op_arg(arg.0) { - Some(t) => t, - None => std::hint::unreachable_unchecked(), - } + // SAFETY: requirements forwarded from caller + unsafe { T::from_op_arg(arg.0).unwrap_unchecked() } } } @@ -595,10 +593,12 @@ pub enum Instruction { TypeVarWithBound, TypeVarWithConstraint, TypeAlias, + TypeVarTuple, + ParamSpec, // If you add a new instruction here, be sure to keep LAST_INSTRUCTION updated } // This must be kept up to date to avoid marshaling errors -const LAST_INSTRUCTION: Instruction = Instruction::TypeAlias; +const LAST_INSTRUCTION: Instruction = Instruction::ParamSpec; const _: () = assert!(mem::size_of::() == 1); impl From for u8 { @@ -1291,6 +1291,8 @@ impl Instruction { TypeVarWithBound => -1, TypeVarWithConstraint => -1, TypeAlias => -2, + ParamSpec => 0, + TypeVarTuple => 0, } } @@ -1460,6 +1462,8 @@ impl Instruction { TypeVarWithBound => 
w!(TypeVarWithBound), TypeVarWithConstraint => w!(TypeVarWithConstraint), TypeAlias => w!(TypeAlias), + ParamSpec => w!(ParamSpec), + TypeVarTuple => w!(TypeVarTuple), } } } diff --git a/compiler/src/lib.rs b/compiler/src/lib.rs index c281058555..7d226cd1ce 100644 --- a/compiler/src/lib.rs +++ b/compiler/src/lib.rs @@ -1,9 +1,9 @@ use rustpython_codegen::{compile, symboltable}; -use rustpython_parser::ast::{self as ast, fold::Fold, ConstantOptimizer}; +use rustpython_parser::ast::{self as ast, ConstantOptimizer, fold::Fold}; pub use rustpython_codegen::compile::CompileOpts; -pub use rustpython_compiler_core::{bytecode::CodeObject, Mode}; -pub use rustpython_parser::{source_code::LinearLocator, Parse}; +pub use rustpython_compiler_core::{Mode, bytecode::CodeObject}; +pub use rustpython_parser::{Parse, source_code::LinearLocator}; // these modules are out of repository. re-exporting them here for convenience. pub use rustpython_codegen as codegen; diff --git a/derive-impl/Cargo.toml b/derive-impl/Cargo.toml index a843f5d3c5..debe58106b 100644 --- a/derive-impl/Cargo.toml +++ b/derive-impl/Cargo.toml @@ -18,10 +18,10 @@ once_cell = { workspace = true } syn = { workspace = true, features = ["full", "extra-traits"] } maplit = "1.0.2" -proc-macro2 = "1.0.79" -quote = "1.0.18" -syn-ext = { version = "0.4.0", features = ["full"] } -textwrap = { version = "0.15.0", default-features = false } +proc-macro2 = "1.0.93" +quote = "1.0.38" +syn-ext = { version = "0.5.0", features = ["full"] } +textwrap = { version = "0.16.1", default-features = false } [lints] workspace = true \ No newline at end of file diff --git a/derive-impl/src/compile_bytecode.rs b/derive-impl/src/compile_bytecode.rs index 6b5baef98c..2a57134197 100644 --- a/derive-impl/src/compile_bytecode.rs +++ b/derive-impl/src/compile_bytecode.rs @@ -13,22 +13,20 @@ //! ) //! 
``` -use crate::{extract_spans, Diagnostic}; +use crate::Diagnostic; use once_cell::sync::Lazy; use proc_macro2::{Span, TokenStream}; use quote::quote; -use rustpython_compiler_core::{bytecode::CodeObject, frozen, Mode}; +use rustpython_compiler_core::{Mode, bytecode::CodeObject, frozen}; use std::{ collections::HashMap, env, fs, path::{Path, PathBuf}, }; use syn::{ - self, - parse::{Parse, ParseStream, Result as ParseResult}, - parse2, + self, LitByteStr, LitStr, Macro, + parse::{ParseStream, Parser, Result as ParseResult}, spanned::Spanned, - Lit, LitByteStr, LitStr, Macro, Meta, MetaNameValue, Token, }; static CARGO_MANIFEST_DIR: Lazy = Lazy::new(|| { @@ -119,11 +117,13 @@ impl CompilationSource { })?; self.compile_string(&source, mode, module_name, compiler, || rel_path.display()) } - CompilationSourceKind::SourceCode(code) => { - self.compile_string(&textwrap::dedent(code), mode, module_name, compiler, || { - "string literal" - }) - } + CompilationSourceKind::SourceCode(code) => self.compile_string( + &textwrap::dedent(code), + mode, + module_name, + compiler, + || "string literal", + ), CompilationSourceKind::Dir(_) => { unreachable!("Can't use compile_single with directory source") } @@ -233,23 +233,17 @@ impl CompilationSource { } } -/// This is essentially just a comma-separated list of Meta nodes, aka the inside of a MetaList. 
-struct PyCompileInput { - span: Span, - metas: Vec, -} - -impl PyCompileInput { - fn parse(&self, allow_dir: bool) -> Result { +impl PyCompileArgs { + fn parse(input: TokenStream, allow_dir: bool) -> Result { let mut module_name = None; let mut mode = None; let mut source: Option = None; let mut crate_name = None; - fn assert_source_empty(source: &Option) -> Result<(), Diagnostic> { + fn assert_source_empty(source: &Option) -> Result<(), syn::Error> { if let Some(source) = source { - Err(Diagnostic::spans_error( - source.span, + Err(syn::Error::new( + source.span.0, "Cannot have more than one source", )) } else { @@ -257,59 +251,58 @@ impl PyCompileInput { } } - for meta in &self.metas { - if let Meta::NameValue(name_value) = meta { - let ident = match name_value.path.get_ident() { - Some(ident) => ident, - None => continue, - }; - let check_str = || match &name_value.lit { - Lit::Str(s) => Ok(s), - _ => Err(err_span!(name_value.lit, "{ident} must be a string")), - }; - if ident == "mode" { - let s = check_str()?; - match s.value().parse() { - Ok(mode_val) => mode = Some(mode_val), - Err(e) => bail_span!(s, "{}", e), - } - } else if ident == "module_name" { - module_name = Some(check_str()?.value()) - } else if ident == "source" { - assert_source_empty(&source)?; - let code = check_str()?.value(); - source = Some(CompilationSource { - kind: CompilationSourceKind::SourceCode(code), - span: extract_spans(&name_value).unwrap(), - }); - } else if ident == "file" { - assert_source_empty(&source)?; - let path = check_str()?.value().into(); - source = Some(CompilationSource { - kind: CompilationSourceKind::File(path), - span: extract_spans(&name_value).unwrap(), - }); - } else if ident == "dir" { - if !allow_dir { - bail_span!(ident, "py_compile doesn't accept dir") - } - - assert_source_empty(&source)?; - let path = check_str()?.value().into(); - source = Some(CompilationSource { - kind: CompilationSourceKind::Dir(path), - span: extract_spans(&name_value).unwrap(), - 
}); - } else if ident == "crate_name" { - let name = check_str()?.parse()?; - crate_name = Some(name); + syn::meta::parser(|meta| { + let ident = meta + .path + .get_ident() + .ok_or_else(|| meta.error("unknown arg"))?; + let check_str = || meta.value()?.call(parse_str); + if ident == "mode" { + let s = check_str()?; + match s.value().parse() { + Ok(mode_val) => mode = Some(mode_val), + Err(e) => bail_span!(s, "{}", e), } + } else if ident == "module_name" { + module_name = Some(check_str()?.value()) + } else if ident == "source" { + assert_source_empty(&source)?; + let code = check_str()?.value(); + source = Some(CompilationSource { + kind: CompilationSourceKind::SourceCode(code), + span: (ident.span(), meta.input.cursor().span()), + }); + } else if ident == "file" { + assert_source_empty(&source)?; + let path = check_str()?.value().into(); + source = Some(CompilationSource { + kind: CompilationSourceKind::File(path), + span: (ident.span(), meta.input.cursor().span()), + }); + } else if ident == "dir" { + if !allow_dir { + bail_span!(ident, "py_compile doesn't accept dir") + } + + assert_source_empty(&source)?; + let path = check_str()?.value().into(); + source = Some(CompilationSource { + kind: CompilationSourceKind::Dir(path), + span: (ident.span(), meta.input.cursor().span()), + }); + } else if ident == "crate_name" { + let name = check_str()?.parse()?; + crate_name = Some(name); + } else { + return Err(meta.error("unknown attr")); } - } + Ok(()) + }) + .parse2(input)?; let source = source.ok_or_else(|| { syn::Error::new( - self.span, + Span::call_site(), "Must have either file or source in py_compile!()/py_freeze!()", ) })?; @@ -323,38 +316,17 @@ impl PyCompileInput { } } -fn parse_meta(input: ParseStream) -> ParseResult { - let path = input.call(syn::Path::parse_mod_style)?; - let eq_token: Token![=] = input.parse()?; +fn parse_str(input: ParseStream) -> ParseResult { let span = input.span(); if input.peek(LitStr) { - Ok(Meta::NameValue(MetaNameValue { - 
path, - eq_token, - lit: Lit::Str(input.parse()?), - })) + input.parse() } else if let Ok(mac) = input.parse::() { - Ok(Meta::NameValue(MetaNameValue { - path, - eq_token, - lit: Lit::Str(LitStr::new(&mac.tokens.to_string(), mac.span())), - })) + Ok(LitStr::new(&mac.tokens.to_string(), mac.span())) } else { Err(syn::Error::new(span, "Expected string or stringify macro")) } } -impl Parse for PyCompileInput { - fn parse(input: ParseStream) -> ParseResult { - let span = input.cursor().span(); - let metas = input - .parse_terminated::(parse_meta)? - .into_iter() - .collect(); - Ok(PyCompileInput { span, metas }) - } -} - struct PyCompileArgs { source: CompilationSource, mode: Mode, @@ -366,8 +338,7 @@ pub fn impl_py_compile( input: TokenStream, compiler: &dyn Compiler, ) -> Result { - let input: PyCompileInput = parse2(input)?; - let args = input.parse(false)?; + let args = PyCompileArgs::parse(input, false)?; let crate_name = args.crate_name; let code = args @@ -388,8 +359,7 @@ pub fn impl_py_freeze( input: TokenStream, compiler: &dyn Compiler, ) -> Result { - let input: PyCompileInput = parse2(input)?; - let args = input.parse(true)?; + let args = PyCompileArgs::parse(input, true)?; let crate_name = args.crate_name; let code_map = args.source.compile(args.mode, args.module_name, compiler)?; diff --git a/derive-impl/src/from_args.rs b/derive-impl/src/from_args.rs index 7b5b684213..2273046ed4 100644 --- a/derive-impl/src/from_args.rs +++ b/derive-impl/src/from_args.rs @@ -1,9 +1,8 @@ use proc_macro2::TokenStream; -use quote::{quote, ToTokens}; +use quote::{ToTokens, quote}; use syn::ext::IdentExt; -use syn::{ - parse_quote, Attribute, Data, DeriveInput, Expr, Field, Ident, Lit, Meta, NestedMeta, Result, -}; +use syn::meta::ParseNestedMeta; +use syn::{Attribute, Data, DeriveInput, Expr, Field, Ident, Result, Token, parse_quote}; /// The kind of the python parameter, this corresponds to the value of Parameter.kind /// 
(https://docs.python.org/3/library/inspect.html#inspect.Parameter.kind) @@ -36,84 +35,61 @@ type DefaultValue = Option; impl ArgAttribute { fn from_attribute(attr: &Attribute) -> Option> { - if !attr.path.is_ident("pyarg") { + if !attr.path().is_ident("pyarg") { return None; } let inner = move || { - let Meta::List(list) = attr.parse_meta()? else { - bail_span!(attr, "pyarg must be a list, like #[pyarg(...)]") - }; - let mut iter = list.nested.iter(); - let first_arg = iter.next().ok_or_else(|| { - err_span!(list, "There must be at least one argument to #[pyarg()]") + let mut arg_attr = None; + attr.parse_nested_meta(|meta| { + let Some(arg_attr) = &mut arg_attr else { + let kind = meta + .path + .get_ident() + .and_then(ParameterKind::from_ident) + .ok_or_else(|| { + meta.error( + "The first argument to #[pyarg()] must be the parameter type, \ + either 'positional', 'any', 'named', or 'flatten'.", + ) + })?; + arg_attr = Some(ArgAttribute { + name: None, + kind, + default: None, + }); + return Ok(()); + }; + arg_attr.parse_argument(meta) })?; - let kind = match first_arg { - NestedMeta::Meta(Meta::Path(path)) => { - path.get_ident().and_then(ParameterKind::from_ident) - } - _ => None, - }; - let kind = kind.ok_or_else(|| { - err_span!( - first_arg, - "The first argument to #[pyarg()] must be the parameter type, either \ - 'positional', 'any', 'named', or 'flatten'." 
- ) - })?; - - let mut attribute = ArgAttribute { - name: None, - kind, - default: None, - }; - - for arg in iter { - attribute.parse_argument(arg)?; - } - - Ok(attribute) + arg_attr + .ok_or_else(|| err_span!(attr, "There must be at least one argument to #[pyarg()]")) }; Some(inner()) } - fn parse_argument(&mut self, arg: &NestedMeta) -> Result<()> { + fn parse_argument(&mut self, meta: ParseNestedMeta<'_>) -> Result<()> { if let ParameterKind::Flatten = self.kind { - bail_span!(arg, "can't put additional arguments on a flatten arg") + return Err(meta.error("can't put additional arguments on a flatten arg")); } - match arg { - NestedMeta::Meta(Meta::Path(path)) => { - if path.is_ident("default") || path.is_ident("optional") { - if self.default.is_none() { - self.default = Some(None); - } - } else { - bail_span!(path, "Unrecognized pyarg attribute"); - } + if meta.path.is_ident("default") && meta.input.peek(Token![=]) { + if matches!(self.default, Some(Some(_))) { + return Err(meta.error("Default already set")); } - NestedMeta::Meta(Meta::NameValue(name_value)) => { - if name_value.path.is_ident("default") { - if matches!(self.default, Some(Some(_))) { - bail_span!(name_value, "Default already set"); - } - - match name_value.lit { - Lit::Str(ref val) => self.default = Some(Some(val.parse()?)), - _ => bail_span!(name_value, "Expected string value for default argument"), - } - } else if name_value.path.is_ident("name") { - if self.name.is_some() { - bail_span!(name_value, "already have a name") - } - - match &name_value.lit { - Lit::Str(val) => self.name = Some(val.value()), - _ => bail_span!(name_value, "Expected string value for name argument"), - } - } else { - bail_span!(name_value, "Unrecognized pyarg attribute"); - } + let val = meta.value()?; + let val = val.parse::()?; + self.default = Some(Some(val.parse()?)) + } else if meta.path.is_ident("default") || meta.path.is_ident("optional") { + if self.default.is_none() { + self.default = Some(None); + } + } else 
if meta.path.is_ident("name") { + if self.name.is_some() { + return Err(meta.error("already have a name")); } - _ => bail_span!(arg, "Unrecognized pyarg attribute"), + let val = meta.value()?.parse::()?; + self.name = Some(val.value()) + } else { + return Err(meta.error("Unrecognized pyarg attribute")); } Ok(()) diff --git a/derive-impl/src/lib.rs b/derive-impl/src/lib.rs index 35292e7de0..a1f97c96b0 100644 --- a/derive-impl/src/lib.rs +++ b/derive-impl/src/lib.rs @@ -20,11 +20,12 @@ mod pypayload; mod pystructseq; mod pytraverse; -use error::{extract_spans, Diagnostic}; +use error::Diagnostic; use proc_macro2::TokenStream; use quote::ToTokens; use rustpython_doc as doc; -use syn::{AttributeArgs, DeriveInput, Item}; +use syn::{DeriveInput, Item}; +use syn_ext::types::PunctuatedNestedMeta; pub use compile_bytecode::Compiler; @@ -38,7 +39,7 @@ pub fn derive_from_args(input: DeriveInput) -> TokenStream { result_to_tokens(from_args::impl_from_args(input)) } -pub fn pyclass(attr: AttributeArgs, item: Item) -> TokenStream { +pub fn pyclass(attr: PunctuatedNestedMeta, item: Item) -> TokenStream { if matches!(item, syn::Item::Impl(_) | syn::Item::Trait(_)) { result_to_tokens(pyclass::impl_pyclass_impl(attr, item)) } else { @@ -46,7 +47,7 @@ pub fn pyclass(attr: AttributeArgs, item: Item) -> TokenStream { } } -pub fn pyexception(attr: AttributeArgs, item: Item) -> TokenStream { +pub fn pyexception(attr: PunctuatedNestedMeta, item: Item) -> TokenStream { if matches!(item, syn::Item::Impl(_)) { result_to_tokens(pyclass::impl_pyexception_impl(attr, item)) } else { @@ -54,7 +55,7 @@ pub fn pyexception(attr: AttributeArgs, item: Item) -> TokenStream { } } -pub fn pymodule(attr: AttributeArgs, item: Item) -> TokenStream { +pub fn pymodule(attr: PunctuatedNestedMeta, item: Item) -> TokenStream { result_to_tokens(pymodule::impl_pymodule(attr, item)) } diff --git a/derive-impl/src/pyclass.rs b/derive-impl/src/pyclass.rs index dedb3eb77a..077bd36bb8 100644 --- 
a/derive-impl/src/pyclass.rs +++ b/derive-impl/src/pyclass.rs @@ -1,17 +1,16 @@ use super::Diagnostic; use crate::util::{ - format_doc, pyclass_ident_and_attrs, pyexception_ident_and_attrs, text_signature, - ClassItemMeta, ContentItem, ContentItemInner, ErrorVec, ExceptionItemMeta, ItemMeta, - ItemMetaInner, ItemNursery, SimpleItemMeta, ALL_ALLOWED_NAMES, + ALL_ALLOWED_NAMES, ClassItemMeta, ContentItem, ContentItemInner, ErrorVec, ExceptionItemMeta, + ItemMeta, ItemMetaInner, ItemNursery, SimpleItemMeta, format_doc, pyclass_ident_and_attrs, + pyexception_ident_and_attrs, text_signature, }; use proc_macro2::{Delimiter, Group, Span, TokenStream, TokenTree}; -use quote::{quote, quote_spanned, ToTokens}; +use quote::{ToTokens, quote, quote_spanned}; use std::collections::{HashMap, HashSet}; use std::str::FromStr; -use syn::{ - parse_quote, spanned::Spanned, Attribute, AttributeArgs, Ident, Item, Meta, NestedMeta, Result, -}; +use syn::{Attribute, Ident, Item, Result, parse_quote, spanned::Spanned}; use syn_ext::ext::*; +use syn_ext::types::*; #[derive(Copy, Clone, Debug)] enum AttrName { @@ -98,7 +97,7 @@ fn extract_items_into_context<'a, Item>( context.errors.ok_or_push(context.member_items.validate()); } -pub(crate) fn impl_pyclass_impl(attr: AttributeArgs, item: Item) -> Result { +pub(crate) fn impl_pyclass_impl(attr: PunctuatedNestedMeta, item: Item) -> Result { let mut context = ImplContext::default(); let mut tokens = match item { Item::Impl(mut imp) => { @@ -127,7 +126,7 @@ pub(crate) fn impl_pyclass_impl(attr: AttributeArgs, item: Item) -> Result is expected but Py{Ref} is found", - )) + )); } } } @@ -135,7 +134,7 @@ pub(crate) fn impl_pyclass_impl(attr: AttributeArgs, item: Item) -> Result is expected but Py{Ref}? 
is found", - )) + )); } } } else { @@ -153,7 +152,7 @@ pub(crate) fn impl_pyclass_impl(attr: AttributeArgs, item: Item) -> Result or T", - )) + )); } }; @@ -235,9 +234,7 @@ pub(crate) fn impl_pyclass_impl(attr: AttributeArgs, item: Item) -> Result { - &method.sig.ident.to_string() == "extend_slots" - } + syn::TraitItem::Fn(item) => item.sig.ident == "extend_slots", _ => false, }; if has { @@ -344,7 +341,7 @@ fn generate_class_def( }; let basicsize = quote!(std::mem::size_of::<#ident>()); let is_pystruct = attrs.iter().any(|attr| { - attr.path.is_ident("derive") + attr.path().is_ident("derive") && if let Ok(Meta::List(l)) = attr.parse_meta() { l.nested .into_iter() @@ -418,7 +415,7 @@ fn generate_class_def( Ok(tokens) } -pub(crate) fn impl_pyclass(attr: AttributeArgs, item: Item) -> Result { +pub(crate) fn impl_pyclass(attr: PunctuatedNestedMeta, item: Item) -> Result { if matches!(item, syn::Item::Use(_)) { return Ok(quote!(#item)); } @@ -534,7 +531,7 @@ pub(crate) fn impl_pyclass(attr: AttributeArgs, item: Item) -> Result Result { +pub(crate) fn impl_pyexception(attr: PunctuatedNestedMeta, item: Item) -> Result { let (ident, _attrs) = pyexception_ident_and_attrs(&item)?; let fake_ident = Ident::new("pyclass", item.span()); let class_meta = ExceptionItemMeta::from_nested(ident.clone(), fake_ident, attr.into_iter())?; @@ -573,7 +570,7 @@ pub(crate) fn impl_pyexception(attr: AttributeArgs, item: Item) -> Result Result { +pub(crate) fn impl_pyexception_impl(attr: PunctuatedNestedMeta, item: Item) -> Result { let Item::Impl(imp) = item else { return Ok(item.into_token_stream()); }; @@ -1229,15 +1226,18 @@ impl MethodItemMeta { let inner = self.inner(); let name = inner._optional_str("name")?; let magic = inner._bool("magic")?; + if magic && name.is_some() { + bail_span!( + &inner.meta_ident, + "A #[{}] method cannot be magic and have a specified name, choose one.", + inner.meta_name() + ); + } Ok(if let Some(name) = name { name } else { let name = inner.item_name(); - 
if magic { - format!("__{name}__") - } else { - name - } + if magic { format!("__{name}__") } else { name } }) } } @@ -1304,11 +1304,7 @@ impl GetSetItemMeta { GetSetItemKind::Set => extract_prefix_name("set_", "setter")?, GetSetItemKind::Delete => extract_prefix_name("del_", "deleter")?, }; - if magic { - format!("__{name}__") - } else { - name - } + if magic { format!("__{name}__") } else { name } }; Ok((py_name, kind)) } @@ -1447,7 +1443,7 @@ struct ExtractedImplAttrs { with_slots: TokenStream, } -fn extract_impl_attrs(attr: AttributeArgs, item: &Ident) -> Result { +fn extract_impl_attrs(attr: PunctuatedNestedMeta, item: &Ident) -> Result { let mut withs = Vec::new(); let mut with_method_defs = Vec::new(); let mut with_slots = Vec::new(); @@ -1467,7 +1463,7 @@ fn extract_impl_attrs(attr: AttributeArgs, item: &Ident) -> Result { + NestedMeta::Meta(Meta::List(MetaList { path, nested, .. })) => { if path.is_ident("with") { for meta in nested { let NestedMeta::Meta(Meta::Path(path)) = &meta else { @@ -1484,7 +1480,10 @@ fn extract_impl_attrs(attr: AttributeArgs, item: &Ident) -> Result Result { + NestedMeta::Meta(Meta::NameValue(syn::MetaNameValue { path, value, .. })) => { if path.is_ident("payload") { - if let syn::Lit::Str(lit) = lit { + if let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(lit), + .. 
+ }) = value + { payload = Some(Ident::new(&lit.value(), lit.span())); } else { - bail_span!(lit, "payload must be a string literal") + bail_span!(value, "payload must be a string literal") } } else { bail_span!(path, "Unknown pyimpl attribute") diff --git a/derive-impl/src/pymodule.rs b/derive-impl/src/pymodule.rs index b9a59c8280..c59e40678e 100644 --- a/derive-impl/src/pymodule.rs +++ b/derive-impl/src/pymodule.rs @@ -1,14 +1,15 @@ use crate::error::Diagnostic; use crate::util::{ - format_doc, iter_use_idents, pyclass_ident_and_attrs, text_signature, AttrItemMeta, - AttributeExt, ClassItemMeta, ContentItem, ContentItemInner, ErrorVec, ItemMeta, ItemNursery, - ModuleItemMeta, SimpleItemMeta, ALL_ALLOWED_NAMES, + ALL_ALLOWED_NAMES, AttrItemMeta, AttributeExt, ClassItemMeta, ContentItem, ContentItemInner, + ErrorVec, ItemMeta, ItemNursery, ModuleItemMeta, SimpleItemMeta, format_doc, iter_use_idents, + pyclass_ident_and_attrs, text_signature, }; use proc_macro2::{Delimiter, Group, TokenStream, TokenTree}; -use quote::{quote, quote_spanned, ToTokens}; +use quote::{ToTokens, quote, quote_spanned}; use std::{collections::HashSet, str::FromStr}; -use syn::{parse_quote, spanned::Spanned, Attribute, AttributeArgs, Ident, Item, Result}; +use syn::{Attribute, Ident, Item, Result, parse_quote, spanned::Spanned}; use syn_ext::ext::*; +use syn_ext::types::PunctuatedNestedMeta; #[derive(Clone, Copy, Eq, PartialEq)] enum AttrName { @@ -51,7 +52,7 @@ struct ModuleContext { errors: Vec, } -pub fn impl_pymodule(attr: AttributeArgs, module_item: Item) -> Result { +pub fn impl_pymodule(attr: PunctuatedNestedMeta, module_item: Item) -> Result { let (doc, mut module_item) = match module_item { Item::Mod(m) => (m.attrs.doc(), m), other => bail_span!(other, "#[pymodule] can only be on a full module"), @@ -687,7 +688,7 @@ impl ModuleItem for AttributeItem { other => { return Err( self.new_syn_error(other.span(), "can only be on a function, const and use") - ) + ); } }; @@ -727,7 +728,7 @@ 
impl ModuleItem for AttributeItem { ( quote_spanned! { ident.span() => { #let_obj - for name in [(#(#names,)*)] { + for name in [#(#names),*] { vm.__module_set_attr(module, vm.ctx.intern_str(name), obj.clone()).unwrap(); } }}, diff --git a/derive-impl/src/pytraverse.rs b/derive-impl/src/pytraverse.rs index 93aa233a18..728722b83a 100644 --- a/derive-impl/src/pytraverse.rs +++ b/derive-impl/src/pytraverse.rs @@ -1,6 +1,6 @@ use proc_macro2::TokenStream; use quote::quote; -use syn::{Attribute, DeriveInput, Field, Meta, MetaList, NestedMeta, Result}; +use syn::{Attribute, DeriveInput, Field, Result}; struct TraverseAttr { /// set to `true` if the attribute is `#[pytraverse(skip)]` @@ -9,47 +9,25 @@ struct TraverseAttr { const ATTR_TRAVERSE: &str = "pytraverse"; -/// get the `#[pytraverse(..)]` attribute from the struct -fn valid_get_traverse_attr_from_meta_list(list: &MetaList) -> Result { - let find_skip_and_only_skip = || { - let len = list.nested.len(); - if len != 1 { - return None; - } - let mut iter = list.nested.iter(); - // we have checked the length, so unwrap is safe - let first_arg = iter.next().unwrap(); - let skip = match first_arg { - NestedMeta::Meta(Meta::Path(path)) => match path.is_ident("skip") { - true => true, - false => return None, - }, - _ => return None, - }; - Some(skip) - }; - let skip = find_skip_and_only_skip().ok_or_else(|| { - err_span!( - list, - "only support attr is #[pytraverse(skip)], got arguments: {:?}", - list.nested - ) - })?; - Ok(TraverseAttr { skip }) -} - /// only accept `#[pytraverse(skip)]` for now fn pytraverse_arg(attr: &Attribute) -> Option> { - if !attr.path.is_ident(ATTR_TRAVERSE) { + if !attr.path().is_ident(ATTR_TRAVERSE) { return None; } let ret = || { - let parsed = attr.parse_meta()?; - if let Meta::List(list) = parsed { - valid_get_traverse_attr_from_meta_list(&list) - } else { - bail_span!(attr, "pytraverse must be a list, like #[pytraverse(skip)]") - } + let mut skip = false; + attr.parse_nested_meta(|meta| { + 
if meta.path.is_ident("skip") { + if skip { + return Err(meta.error("already specified skip")); + } + skip = true; + } else { + return Err(meta.error("unknown attr")); + } + Ok(()) + })?; + Ok(TraverseAttr { skip }) }; Some(ret()) } @@ -92,7 +70,7 @@ fn gen_trace_code(item: &mut DeriveInput) -> Result { syn::Data::Struct(s) => { let fields = &mut s.fields; match fields { - syn::Fields::Named(ref mut fields) => { + syn::Fields::Named(fields) => { let res: Vec = fields .named .iter_mut() diff --git a/derive-impl/src/util.rs b/derive-impl/src/util.rs index f016b0d1e9..7e0eb96fb0 100644 --- a/derive-impl/src/util.rs +++ b/derive-impl/src/util.rs @@ -1,13 +1,11 @@ use itertools::Itertools; use proc_macro2::{Span, TokenStream}; -use quote::{quote, ToTokens}; +use quote::{ToTokens, quote}; use std::collections::{HashMap, HashSet}; -use syn::{ - spanned::Spanned, Attribute, Ident, Meta, MetaList, NestedMeta, Result, Signature, UseTree, -}; +use syn::{Attribute, Ident, Result, Signature, UseTree, spanned::Spanned}; use syn_ext::{ ext::{AttributeExt as SynAttributeExt, *}, - types::PunctuatedNestedMeta, + types::*, }; pub(crate) const ALL_ALLOWED_NAMES: &[&str] = &[ @@ -77,7 +75,7 @@ impl ItemNursery { impl ToTokens for ValidatedItemNursery { fn to_tokens(&self, tokens: &mut TokenStream) { - let mut sorted = self.0 .0.clone(); + let mut sorted = self.0.0.clone(); sorted.sort_by(|a, b| a.sort_order.cmp(&b.sort_order)); tokens.extend(sorted.iter().map(|item| { let cfgs = &item.cfgs; @@ -167,7 +165,11 @@ impl ItemMetaInner { pub fn _optional_str(&self, key: &str) -> Result> { let value = if let Some((_, meta)) = self.meta_map.get(key) { let Meta::NameValue(syn::MetaNameValue { - lit: syn::Lit::Str(lit), + value: + syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(lit), + .. + }), .. 
}) = meta else { @@ -193,7 +195,11 @@ impl ItemMetaInner { let value = if let Some((_, meta)) = self.meta_map.get(key) { match meta { Meta::NameValue(syn::MetaNameValue { - lit: syn::Lit::Bool(lit), + value: + syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Bool(lit), + .. + }), .. }) => lit.value, Meta::Path(_) => true, @@ -210,7 +216,7 @@ impl ItemMetaInner { key: &str, ) -> Result>> { let value = if let Some((_, meta)) = self.meta_map.get(key) { - let Meta::List(syn::MetaList { + let Meta::List(MetaList { path: _, nested, .. }) = meta else { @@ -350,7 +356,11 @@ impl ClassItemMeta { if let Some((_, meta)) = inner.meta_map.get(KEY) { match meta { Meta::NameValue(syn::MetaNameValue { - lit: syn::Lit::Str(lit), + value: + syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(lit), + .. + }), .. }) => return Ok(lit.value()), Meta::Path(_) => return Ok(inner.item_name()), @@ -387,11 +397,11 @@ impl ClassItemMeta { let value = if let Some((_, meta)) = inner.meta_map.get(KEY) { match meta { Meta::NameValue(syn::MetaNameValue { - lit: syn::Lit::Str(lit), + value: syn::Expr::Lit(syn::ExprLit{lit:syn::Lit::Str(lit),..}), .. }) => Ok(Some(lit.value())), Meta::NameValue(syn::MetaNameValue { - lit: syn::Lit::Bool(lit), + value: syn::Expr::Lit(syn::ExprLit{lit:syn::Lit::Bool(lit),..}), .. }) => if lit.value { Err(lit.span()) @@ -437,7 +447,7 @@ impl ItemMeta for ExceptionItemMeta { Self(ClassItemMeta(inner)) } fn inner(&self) -> &ItemMetaInner { - &self.0 .0 + &self.0.0 } } @@ -448,7 +458,11 @@ impl ExceptionItemMeta { if let Some((_, meta)) = inner.meta_map.get(KEY) { match meta { Meta::NameValue(syn::MetaNameValue { - lit: syn::Lit::Str(lit), + value: + syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(lit), + .. + }), .. 
}) => return Ok(lit.value()), Meta::Path(_) => { @@ -456,12 +470,12 @@ impl ExceptionItemMeta { let type_name = inner.item_name(); let Some(py_name) = type_name.as_str().strip_prefix("Py") else { bail_span!( - inner.item_ident, - "#[pyexception] expects its underlying type to be named `Py` prefixed" - ) + inner.item_ident, + "#[pyexception] expects its underlying type to be named `Py` prefixed" + ) }; py_name.to_string() - }) + }); } _ => {} } @@ -489,7 +503,7 @@ impl std::ops::Deref for ExceptionItemMeta { pub(crate) trait AttributeExt: SynAttributeExt { fn promoted_nested(&self) -> Result; fn ident_and_promoted_nested(&self) -> Result<(&Ident, PunctuatedNestedMeta)>; - fn try_remove_name(&mut self, name: &str) -> Result>; + fn try_remove_name(&mut self, name: &str) -> Result>; fn fill_nested_meta(&mut self, name: &str, new_item: F) -> Result<()> where F: Fn() -> NestedMeta; @@ -512,10 +526,10 @@ impl AttributeExt for Attribute { Ok((self.get_ident().unwrap(), self.promoted_nested()?)) } - fn try_remove_name(&mut self, item_name: &str) -> Result> { + fn try_remove_name(&mut self, item_name: &str) -> Result> { self.try_meta_mut(|meta| { let nested = match meta { - Meta::List(MetaList { ref mut nested, .. }) => Ok(nested), + Meta::List(MetaList { nested, .. }) => Ok(nested), other => Err(syn::Error::new( other.span(), format!( diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 76b7e23488..97ca026e76 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -5,6 +5,7 @@ use proc_macro::TokenStream; use rustpython_derive_impl as derive_impl; use syn::parse_macro_input; +use syn::punctuated::Punctuated; #[proc_macro_derive(FromArgs, attributes(pyarg))] pub fn derive_from_args(input: TokenStream) -> TokenStream { @@ -12,9 +13,127 @@ pub fn derive_from_args(input: TokenStream) -> TokenStream { derive_impl::derive_from_args(input).into() } +/// The attribute can be applied either to a struct, trait, or impl. 
+/// # Struct
+/// This implements `MaybeTraverse`, `PyClassDef`, and `StaticType` for the struct.
+/// Consider deriving `Traverse` to implement it.
+/// ## Arguments
+/// - `module`: the module which contains the class -- can be omitted if in a `#[pymodule]`.
+/// - `name`: the name of the Python class, by default it is the name of the struct.
+/// - `base`: the base class of the Python class.
+/// This does not cause inheritance of functions or attributes; that must be done by a separate trait.
+/// # Impl
+/// This part implements `PyClassImpl` for the struct.
+/// This includes methods, getters/setters, etc.; only annotated methods will be included.
+/// Common functions and abilities like instantiation and `__call__` are often implemented by
+/// traits rather than in the `impl` itself; see `Constructor` and `Callable` respectively for those.
+/// ## Arguments
+/// - `name`: the name of the Python class, when no name is provided the struct name is used.
+/// - `flags`: the flags of the class, see `PyTypeFlags`.
+/// - `BASETYPE`: allows the class to be inheritable.
+/// - `IMMUTABLETYPE`: class attributes are immutable.
+/// - `with`: which trait implementations are to be included in the Python class.
+/// ```rust, ignore
+/// #[pyclass(module = "mymodule", name = "MyClass", base = "BaseClass")]
+/// struct MyStruct {
+/// x: i32,
+/// }
+///
+/// impl Constructor for MyStruct {
+/// ...
+/// }
+///
+/// #[pyclass(with(Constructor))]
+/// impl MyStruct {
+/// ...
+/// }
+/// ```
+/// ## Inner markers
+/// ### pymethod/pyclassmethod/pystaticmethod
+/// `pymethod` is used to mark a method of the Python class.
+/// `pyclassmethod` is used to mark a class method.
+/// `pystaticmethod` is used to mark a static method.
+/// #### Method signature
+/// The first parameter can be either `&self` or `: PyRef` for `pymethod`.
+/// The first parameter can be `cls: PyTypeRef` for `pyclassmethod`.
+/// There is no mandatory parameter for `pystaticmethod`.
+/// Both are valid and essentially the same, but the latter can yield more control. +/// The last parameter can optionally be of the type `&VirtualMachine` to access the VM. +/// All other values must implement `IntoPyResult`. +/// Numeric types, `String`, `bool`, and `PyObjectRef` implement this trait, +/// but so does any object that implements `PyValue`. +/// Consider using `OptionalArg` for optional arguments. +/// #### Arguments +/// - `magic`: marks the method as a magic method: the method name is surrounded with double underscores. +/// ```rust, ignore +/// #[pyclass] +/// impl MyStruct { +/// // This will be called as the `__add__` method in Python. +/// #[pymethod(magic)] +/// fn add(&self, other: &Self) -> PyResult { +/// ... +/// } +/// } +/// ``` +/// - `name`: the name of the method in Python, +/// by default it is the same as the Rust method, or surrounded by double underscores if magic is present. +/// This overrides `magic` and the default name and cannot be used with `magic` to prevent ambiguity. +/// ### pygetset +/// This is used to mark a getter/setter pair. +/// #### Arguments +/// - `setter`: marks the method as a setter, it acts as a getter by default. +/// Setter method names should be prefixed with `set_`. +/// - `name`: the name of the attribute in Python, by default it is the same as the Rust method. +/// - `magic`: marks the method as a magic method: the method name is surrounded with double underscores. +/// This cannot be used with `name` to prevent ambiguity. +/// +/// Ensure both the getter and setter are marked with `name` and `magic` in the same manner. 
+/// #### Examples
+/// ```rust, ignore
+/// #[pyclass]
+/// impl MyStruct {
+/// #[pygetset]
+/// fn x(&self) -> PyResult {
+/// Ok(self.x.lock())
+/// }
+/// #[pygetset(setter)]
+/// fn set_x(&mut self, value: i32) -> PyResult<()> {
+/// self.x.set(value);
+/// Ok(())
+/// }
+/// }
+/// ```
+/// ### pyslot
+/// This is used to mark a slot method; it should be marked by prefixing the method in Rust with `slot_`.
+/// #### Arguments
+/// - name: the name of the slot method.
+/// ### pyattr
+/// ### extend_class
+/// This helps inherit attributes from a parent class.
+/// The method this is applied on should be called `extend_class_with_fields`.
+/// #### Examples
+/// ```rust, ignore
+/// #[extend_class]
+/// fn extend_class_with_fields(ctx: &Context, class: &'static Py) {
+/// class.set_attr(
+/// identifier!(ctx, _fields),
+/// ctx.new_tuple(vec![
+/// ctx.new_str(ascii!("body")).into(),
+/// ctx.new_str(ascii!("type_ignores")).into(),
+/// ])
+/// .into(),
+/// );
+/// class.set_attr(identifier!(ctx, _attributes), ctx.new_list(vec![]).into());
+/// }
+/// ```
+/// ### pymember
+/// # Trait
+/// `#[pyclass]` on traits functions a lot like `#[pyclass]` on `impl` blocks.
+/// Note that associated functions that are annotated with `#[pymethod]` or similar **must**
+/// have a body; abstract functions should be wrapped before applying an annotation.
#[proc_macro_attribute] pub fn pyclass(attr: TokenStream, item: TokenStream) -> TokenStream { - let attr = parse_macro_input!(attr); + let attr = parse_macro_input!(attr with Punctuated::parse_terminated); let item = parse_macro_input!(item); derive_impl::pyclass(attr, item).into() } @@ -29,14 +148,79 @@ pub fn pyclass(attr: TokenStream, item: TokenStream) -> TokenStream { /// #[proc_macro_attribute] pub fn pyexception(attr: TokenStream, item: TokenStream) -> TokenStream { - let attr = parse_macro_input!(attr); + let attr = parse_macro_input!(attr with Punctuated::parse_terminated); let item = parse_macro_input!(item); derive_impl::pyexception(attr, item).into() } +/// This attribute must be applied to an inline module. +/// It defines a Python module in the form a `make_module` function in the module; +/// this has to be used in a `get_module_inits` to properly register the module. +/// Additionally, this macro defines 'MODULE_NAME' and 'DOC' in the module. +/// # Arguments +/// - `name`: the name of the python module, +/// by default, it is the name of the module, but this can be configured. +/// ```rust, ignore +/// // This will create a module named `mymodule` +/// #[pymodule(name = "mymodule")] +/// mod module { +/// } +/// ``` +/// - `sub`: declare the module as a submodule of another module. +/// ```rust, ignore +/// #[pymodule(sub)] +/// mod submodule { +/// } +/// +/// #[pymodule(with(submodule))] +/// mod mymodule { +/// } +/// ``` +/// - `with`: declare the list of submodules that this module contains (see `sub` for example). +/// ## Inner markers +/// ### pyattr +/// `pyattr` is a multipurpose marker that can be used in a pymodule. +/// The most common use is to mark a function or class as a part of the module. +/// This can be done by applying it to a function or struct prior to the `#[pyfunction]` or `#[pyclass]` macro. +/// If applied to a constant, it will be added to the module as an attribute. 
+/// If applied to a function not marked with `pyfunction`, +/// it will also be added to the module as an attribute but the value is the result of the function. +/// If `#[pyattr(once)]` is used in this case, the function will be called once +/// and the result will be stored using a `static_cell`. +/// #### Examples +/// ```rust, ignore +/// #[pymodule] +/// mod mymodule { +/// #[pyattr] +/// const MY_CONSTANT: i32 = 42; +/// #[pyattr] +/// fn another_constant() -> PyResult { +/// Ok(42) +/// } +/// #[pyattr(once)] +/// fn once() -> PyResult { +/// // This will only be called once and the result will be stored. +/// Ok(2 ** 24) +/// } +/// +/// #[pyattr] +/// #[pyfunction] +/// fn my_function(vm: &VirtualMachine) -> PyResult<()> { +/// ... +/// } +/// } +/// ``` +/// ### pyfunction +/// This is used to create a python function. +/// #### Function signature +/// The last argument can optionally be of the type `&VirtualMachine` to access the VM. +/// Refer to the `pymethod` documentation (located in the `pyclass` macro documentation) +/// for more information on what regular argument types are permitted. +/// #### Arguments +/// - `name`: the name of the function in Python, by default it is the same as the associated Rust function. #[proc_macro_attribute] pub fn pymodule(attr: TokenStream, item: TokenStream) -> TokenStream { - let attr = parse_macro_input!(attr); + let attr = parse_macro_input!(attr with Punctuated::parse_terminated); let item = parse_macro_input!(item); derive_impl::pymodule(attr, item).into() } @@ -61,7 +245,7 @@ impl derive_impl::Compiler for Compiler { mode: rustpython_compiler::Mode, module_name: String, ) -> Result> { - use rustpython_compiler::{compile, CompileOpts}; + use rustpython_compiler::{CompileOpts, compile}; Ok(compile(source, mode, module_name, CompileOpts::default())?) 
} } diff --git a/examples/call_between_rust_and_python.rs b/examples/call_between_rust_and_python.rs index 78eef62200..576390d059 100644 --- a/examples/call_between_rust_and_python.rs +++ b/examples/call_between_rust_and_python.rs @@ -1,5 +1,5 @@ use rustpython::vm::{ - pyclass, pymodule, PyObject, PyPayload, PyResult, TryFromBorrowedObject, VirtualMachine, + PyObject, PyPayload, PyResult, TryFromBorrowedObject, VirtualMachine, pyclass, pymodule, }; pub fn main() { @@ -31,7 +31,7 @@ pub fn main() { #[pymodule] mod rust_py_module { use super::*; - use rustpython::vm::{builtins::PyList, convert::ToPyObject, PyObjectRef}; + use rustpython::vm::{PyObjectRef, builtins::PyList, convert::ToPyObject}; #[pyfunction] fn rust_function( diff --git a/examples/generator.rs b/examples/generator.rs index 010ccd3797..937687ab8f 100644 --- a/examples/generator.rs +++ b/examples/generator.rs @@ -1,9 +1,9 @@ use rustpython_vm as vm; use std::process::ExitCode; use vm::{ + Interpreter, PyResult, builtins::PyIntRef, protocol::{PyIter, PyIterReturn}, - Interpreter, PyResult, }; fn py_main(interp: &Interpreter) -> vm::PyResult<()> { diff --git a/examples/package_embed.rs b/examples/package_embed.rs index fffa98623a..975e734593 100644 --- a/examples/package_embed.rs +++ b/examples/package_embed.rs @@ -1,6 +1,6 @@ use rustpython_vm as vm; use std::process::ExitCode; -use vm::{builtins::PyStrRef, Interpreter}; +use vm::{Interpreter, builtins::PyStrRef}; fn py_main(interp: &Interpreter) -> vm::PyResult { interp.enter(|vm| { diff --git a/examples/parse_folder.rs b/examples/parse_folder.rs index c10450e018..f54be635c8 100644 --- a/examples/parse_folder.rs +++ b/examples/parse_folder.rs @@ -12,7 +12,7 @@ extern crate env_logger; extern crate log; use clap::{App, Arg}; -use rustpython_parser::{ast, Parse}; +use rustpython_parser::{Parse, ast}; use std::{ path::Path, time::{Duration, Instant}, diff --git a/extra_tests/snippets/builtin_object.py b/extra_tests/snippets/builtin_object.py index 
ef83da83e2..5a12afbf45 100644 --- a/extra_tests/snippets/builtin_object.py +++ b/extra_tests/snippets/builtin_object.py @@ -7,9 +7,7 @@ class MyObject: assert myobj == myobj assert not myobj != myobj -object.__subclasshook__() == NotImplemented object.__subclasshook__(1) == NotImplemented -object.__subclasshook__(1, 2) == NotImplemented assert MyObject().__eq__(MyObject()) == NotImplemented assert MyObject().__ne__(MyObject()) == NotImplemented diff --git a/extra_tests/snippets/builtins_ctypes.py b/extra_tests/snippets/builtins_ctypes.py new file mode 100644 index 0000000000..5bd6e5ef25 --- /dev/null +++ b/extra_tests/snippets/builtins_ctypes.py @@ -0,0 +1,133 @@ +import os as _os, sys as _sys + +from _ctypes import sizeof +from _ctypes import _SimpleCData +from struct import calcsize as _calcsize + +def create_string_buffer(init, size=None): + """create_string_buffer(aBytes) -> character array + create_string_buffer(anInteger) -> character array + create_string_buffer(aBytes, anInteger) -> character array + """ + if isinstance(init, bytes): + if size is None: + size = len(init)+1 + _sys.audit("ctypes.create_string_buffer", init, size) + buftype = c_char * size + buf = buftype() + buf.value = init + return buf + elif isinstance(init, int): + _sys.audit("ctypes.create_string_buffer", None, init) + buftype = c_char * init + buf = buftype() + return buf + raise TypeError(init) + +def _check_size(typ, typecode=None): + # Check if sizeof(ctypes_type) against struct.calcsize. This + # should protect somewhat against a misconfigured libffi. 
+ from struct import calcsize + if typecode is None: + # Most _type_ codes are the same as used in struct + typecode = typ._type_ + actual, required = sizeof(typ), calcsize(typecode) + if actual != required: + raise SystemError("sizeof(%s) wrong: %d instead of %d" % \ + (typ, actual, required)) + +class c_short(_SimpleCData): + _type_ = "h" +_check_size(c_short) + +class c_ushort(_SimpleCData): + _type_ = "H" +_check_size(c_ushort) + +class c_long(_SimpleCData): + _type_ = "l" +_check_size(c_long) + +class c_ulong(_SimpleCData): + _type_ = "L" +_check_size(c_ulong) + +if _calcsize("i") == _calcsize("l"): + # if int and long have the same size, make c_int an alias for c_long + c_int = c_long + c_uint = c_ulong +else: + class c_int(_SimpleCData): + _type_ = "i" + _check_size(c_int) + + class c_uint(_SimpleCData): + _type_ = "I" + _check_size(c_uint) + +class c_float(_SimpleCData): + _type_ = "f" +_check_size(c_float) + +class c_double(_SimpleCData): + _type_ = "d" +_check_size(c_double) + +class c_longdouble(_SimpleCData): + _type_ = "g" +if sizeof(c_longdouble) == sizeof(c_double): + c_longdouble = c_double + +if _calcsize("l") == _calcsize("q"): + # if long and long long have the same size, make c_longlong an alias for c_long + c_longlong = c_long + c_ulonglong = c_ulong +else: + class c_longlong(_SimpleCData): + _type_ = "q" + _check_size(c_longlong) + + class c_ulonglong(_SimpleCData): + _type_ = "Q" + ## def from_param(cls, val): + ## return ('d', float(val), val) + ## from_param = classmethod(from_param) + _check_size(c_ulonglong) + +class c_ubyte(_SimpleCData): + _type_ = "B" +c_ubyte.__ctype_le__ = c_ubyte.__ctype_be__ = c_ubyte +# backward compatibility: +##c_uchar = c_ubyte +_check_size(c_ubyte) + +class c_byte(_SimpleCData): + _type_ = "b" +c_byte.__ctype_le__ = c_byte.__ctype_be__ = c_byte +_check_size(c_byte) + +class c_char(_SimpleCData): + _type_ = "c" +c_char.__ctype_le__ = c_char.__ctype_be__ = c_char +_check_size(c_char) + +class 
c_char_p(_SimpleCData): + _type_ = "z" + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, c_void_p.from_buffer(self).value) +_check_size(c_char_p, "P") + +class c_void_p(_SimpleCData): + _type_ = "P" +c_voidp = c_void_p # backwards compatibility (to a bug) +_check_size(c_void_p) + +class c_bool(_SimpleCData): + _type_ = "?" +_check_size(c_bool) + +i = c_int(42) +f = c_float(3.14) +# s = create_string_buffer(b'\000' * 32) +assert i.value == 42 +assert abs(f.value - 3.14) < 1e-06 diff --git a/extra_tests/snippets/stdlib_imghdr.py b/extra_tests/snippets/stdlib_imghdr.py deleted file mode 100644 index 5ca524e269..0000000000 --- a/extra_tests/snippets/stdlib_imghdr.py +++ /dev/null @@ -1,27 +0,0 @@ -# unittest for modified imghdr.py -# Should be replace it into https://github.com/python/cpython/blob/main/Lib/test/test_imghdr.py -import os -import imghdr - - -TEST_FILES = ( - #('python.png', 'png'), - ('python.gif', 'gif'), - ('python.bmp', 'bmp'), - ('python.ppm', 'ppm'), - ('python.pgm', 'pgm'), - ('python.pbm', 'pbm'), - ('python.jpg', 'jpeg'), - ('python.ras', 'rast'), - #('python.sgi', 'rgb'), - ('python.tiff', 'tiff'), - ('python.xbm', 'xbm'), - ('python.webp', 'webp'), - ('python.exr', 'exr'), -) - -resource_dir = os.path.join(os.path.dirname(__file__), 'imghdrdata') - -for fname, expected in TEST_FILES: - res = imghdr.what(os.path.join(resource_dir, fname)) - assert res == expected \ No newline at end of file diff --git a/extra_tests/snippets/stdlib_math.py b/extra_tests/snippets/stdlib_math.py index 442bbc97a5..94d8c7347c 100644 --- a/extra_tests/snippets/stdlib_math.py +++ b/extra_tests/snippets/stdlib_math.py @@ -291,3 +291,5 @@ def assertAllNotClose(examples, *args, **kwargs): assert math.fmod(-3.0, NINF) == -3.0 assert math.fmod(0.0, 3.0) == 0.0 assert math.fmod(0.0, NINF) == 0.0 + +assert math.gamma(1) == 1.0 diff --git a/extra_tests/snippets/stdlib_subprocess.py b/extra_tests/snippets/stdlib_subprocess.py index 96ead76583..2e3aa7b2c1 
100644 --- a/extra_tests/snippets/stdlib_subprocess.py +++ b/extra_tests/snippets/stdlib_subprocess.py @@ -16,7 +16,7 @@ def echo(text): return ["cmd", "/C", f"echo {text}"] def sleep(secs): # TODO: make work in a non-unixy environment (something with timeout.exe?) - return ["sleep", str(secs)] + return ["powershell", "/C", "sleep", str(secs)] p = subprocess.Popen(echo("test")) @@ -32,7 +32,7 @@ def sleep(secs): assert p.poll() is None with assert_raises(subprocess.TimeoutExpired): - assert p.wait(1) + assert p.wait(1) p.wait() @@ -48,17 +48,17 @@ def sleep(secs): p.terminate() p.wait() if is_unix: - assert p.returncode == -signal.SIGTERM + assert p.returncode == -signal.SIGTERM else: - assert p.returncode == 1 + assert p.returncode == 1 p = subprocess.Popen(sleep(2)) p.kill() p.wait() if is_unix: - assert p.returncode == -signal.SIGKILL + assert p.returncode == -signal.SIGKILL else: - assert p.returncode == 1 + assert p.returncode == 1 p = subprocess.Popen(echo("test"), stdout=subprocess.PIPE) (stdout, stderr) = p.communicate() @@ -66,4 +66,4 @@ def sleep(secs): p = subprocess.Popen(sleep(5), stdout=subprocess.PIPE) with assert_raises(subprocess.TimeoutExpired): - p.communicate(timeout=1) + p.communicate(timeout=1) diff --git a/extra_tests/snippets/stdlib_xdrlib.py b/extra_tests/snippets/stdlib_xdrlib.py deleted file mode 100644 index 681cd77467..0000000000 --- a/extra_tests/snippets/stdlib_xdrlib.py +++ /dev/null @@ -1,12 +0,0 @@ -# This probably will be superceeded by the python unittests when that works. 
- -import xdrlib - -p = xdrlib.Packer() -p.pack_int(1337) - -d = p.get_buffer() - -print(d) - -# assert d == b'\x00\x00\x059' diff --git a/extra_tests/snippets/syntax_class.py b/extra_tests/snippets/syntax_class.py index 28d066a9e9..fb702f3304 100644 --- a/extra_tests/snippets/syntax_class.py +++ b/extra_tests/snippets/syntax_class.py @@ -50,7 +50,7 @@ def kungfu(x): assert x == 3 -assert Bar.__doc__ == " W00t " +assert Bar.__doc__ == "W00t " bar = Bar(42) assert bar.get_x.__doc__ == None @@ -147,7 +147,7 @@ class T3: test3 """ -assert T3.__doc__ == "\n test3\n " +assert T3.__doc__ == "\ntest3\n" class T4: diff --git a/extra_tests/snippets/syntax_doc.py b/extra_tests/snippets/syntax_doc.py new file mode 100644 index 0000000000..bdfc1fe778 --- /dev/null +++ b/extra_tests/snippets/syntax_doc.py @@ -0,0 +1,15 @@ + +def f1(): + """ + x + \ty + """ +assert f1.__doc__ == '\nx\ny\n' + +def f2(): + """ +\t x +\t\ty + """ + +assert f2.__doc__ == '\nx\n y\n' diff --git a/extra_tests/snippets/syntax_function2.py b/extra_tests/snippets/syntax_function2.py index ebea34fe58..dce4cb54eb 100644 --- a/extra_tests/snippets/syntax_function2.py +++ b/extra_tests/snippets/syntax_function2.py @@ -44,7 +44,7 @@ def f3(): """ pass -assert f3.__doc__ == "\n test3\n " +assert f3.__doc__ == "\ntest3\n" def f4(): "test4" diff --git a/jit/src/lib.rs b/jit/src/lib.rs index 99bfb45c78..37f1f2a3dd 100644 --- a/jit/src/lib.rs +++ b/jit/src/lib.rs @@ -152,12 +152,14 @@ impl CompiledCode { } unsafe fn invoke_raw(&self, cif_args: &[libffi::middle::Arg]) -> Option { - let cif = self.sig.to_cif(); - let value = cif.call::( - libffi::middle::CodePtr::from_ptr(self.code as *const _), - cif_args, - ); - self.sig.ret.as_ref().map(|ty| value.to_typed(ty)) + unsafe { + let cif = self.sig.to_cif(); + let value = cif.call::( + libffi::middle::CodePtr::from_ptr(self.code as *const _), + cif_args, + ); + self.sig.ret.as_ref().map(|ty| value.to_typed(ty)) + } } } @@ -213,9 +215,9 @@ pub enum AbiValue { impl 
AbiValue { fn to_libffi_arg(&self) -> libffi::middle::Arg { match self { - AbiValue::Int(ref i) => libffi::middle::Arg::new(i), - AbiValue::Float(ref f) => libffi::middle::Arg::new(f), - AbiValue::Bool(ref b) => libffi::middle::Arg::new(b), + AbiValue::Int(i) => libffi::middle::Arg::new(i), + AbiValue::Float(f) => libffi::middle::Arg::new(f), + AbiValue::Bool(b) => libffi::middle::Arg::new(b), } } } @@ -290,10 +292,12 @@ union UnTypedAbiValue { impl UnTypedAbiValue { unsafe fn to_typed(self, ty: &JitType) -> AbiValue { - match ty { - JitType::Int => AbiValue::Int(self.int), - JitType::Float => AbiValue::Float(self.float), - JitType::Bool => AbiValue::Bool(self.boolean != 0), + unsafe { + match ty { + JitType::Int => AbiValue::Int(self.int), + JitType::Float => AbiValue::Float(self.float), + JitType::Bool => AbiValue::Bool(self.boolean != 0), + } } } } diff --git a/rustfmt.toml b/rustfmt.toml index 3a26366d4d..f216078d96 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1 +1 @@ -edition = "2021" +edition = "2024" diff --git a/src/interpreter.rs b/src/interpreter.rs index 9d4dc4ee98..35710ae829 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -1,4 +1,4 @@ -use rustpython_vm::{builtins::PyModule, Interpreter, PyRef, Settings, VirtualMachine}; +use rustpython_vm::{Interpreter, PyRef, Settings, VirtualMachine, builtins::PyModule}; pub type InitHook = Box; diff --git a/src/lib.rs b/src/lib.rs index 0e35a6ad83..b0a176acf2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -50,14 +50,14 @@ mod interpreter; mod settings; mod shell; -use rustpython_vm::{scope::Scope, PyResult, VirtualMachine}; +use rustpython_vm::{PyResult, VirtualMachine, scope::Scope}; use std::env; use std::io::IsTerminal; use std::process::ExitCode; pub use interpreter::InterpreterConfig; pub use rustpython_vm as vm; -pub use settings::{opts_with_clap, InstallPipMode, RunMode}; +pub use settings::{InstallPipMode, RunMode, opts_with_clap}; pub use shell::run_shell; /// The main cli of the `rustpython` 
interpreter. This function will return `std::process::ExitCode` @@ -78,7 +78,7 @@ pub fn run(init: impl FnOnce(&mut VirtualMachine) + 'static) -> ExitCode { // don't translate newlines (\r\n <=> \n) #[cfg(windows)] { - extern "C" { + unsafe extern "C" { fn _setmode(fd: i32, flags: i32) -> i32; } unsafe { diff --git a/src/settings.rs b/src/settings.rs index f6b7f21f25..35114374c8 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -238,7 +238,9 @@ fn settings_from(matches: &ArgMatches) -> (Settings, RunMode) { settings.int_max_str_digits = match env::var("PYTHONINTMAXSTRDIGITS").unwrap().parse() { Ok(digits @ (0 | 640..)) => digits, _ => { - error!("Fatal Python error: config_init_int_max_str_digits: PYTHONINTMAXSTRDIGITS: invalid limit; must be >= 640 or 0 for unlimited.\nPython runtime state: preinitialized"); + error!( + "Fatal Python error: config_init_int_max_str_digits: PYTHONINTMAXSTRDIGITS: invalid limit; must be >= 640 or 0 for unlimited.\nPython runtime state: preinitialized" + ); std::process::exit(1); } }; diff --git a/src/shell.rs b/src/shell.rs index 91becb0bbf..00b6710061 100644 --- a/src/shell.rs +++ b/src/shell.rs @@ -1,12 +1,12 @@ mod helper; -use rustpython_parser::{lexer::LexicalErrorType, ParseErrorType, Tok}; +use rustpython_parser::{ParseErrorType, Tok, lexer::LexicalErrorType}; use rustpython_vm::{ + AsObject, PyResult, VirtualMachine, builtins::PyBaseExceptionRef, compiler::{self, CompileError, CompileErrorType}, readline::{Readline, ReadlineResult}, scope::Scope, - AsObject, PyResult, VirtualMachine, }; enum ShellExecResult { diff --git a/src/shell/helper.rs b/src/shell/helper.rs index c228dbdf65..0fe2f5ca93 100644 --- a/src/shell/helper.rs +++ b/src/shell/helper.rs @@ -1,8 +1,9 @@ #![cfg_attr(target_arch = "wasm32", allow(dead_code))] use rustpython_vm::{ + AsObject, PyResult, TryFromObject, VirtualMachine, builtins::{PyDictRef, PyStrRef}, function::ArgIterable, - identifier, AsObject, PyResult, TryFromObject, VirtualMachine, + 
identifier, }; pub struct ShellHelper<'vm> { @@ -107,7 +108,7 @@ impl<'vm> ShellHelper<'vm> { .filter(|res| { res.as_ref() .ok() - .map_or(true, |s| s.as_str().starts_with(word_start)) + .is_none_or(|s| s.as_str().starts_with(word_start)) }) .collect::, _>>() .ok()?; diff --git a/stdlib/Cargo.toml b/stdlib/Cargo.toml index 934c9b5cf1..1086dea090 100644 --- a/stdlib/Cargo.toml +++ b/stdlib/Cargo.toml @@ -14,11 +14,10 @@ license.workspace = true default = ["compiler"] compiler = ["rustpython-vm/compiler"] threading = ["rustpython-common/threading", "rustpython-vm/threading"] -zlib = ["libz-sys", "flate2/zlib"] bz2 = ["bzip2"] sqlite = ["dep:libsqlite3-sys"] -ssl = ["openssl", "openssl-sys", "foreign-types-shared"] -ssl-vendor = ["ssl", "openssl/vendored", "openssl-probe"] +ssl = ["openssl", "openssl-sys", "foreign-types-shared", "openssl-probe"] +ssl-vendor = ["ssl", "openssl/vendored"] [dependencies] # rustpython crates @@ -46,16 +45,14 @@ thread_local = { workspace = true } memchr = { workspace = true } base64 = "0.13.0" -csv-core = "0.1.10" +csv-core = "0.1.11" dyn-clone = "1.0.10" -libz-sys = { version = "1.1", default-features = false, optional = true } -puruspe = "0.2.4" +puruspe = "0.4.0" xml-rs = "0.8.14" # random rand = { workspace = true } -rand_core = "0.6.3" -mt19937 = "2.0.1" +mt19937 = "3.1" # Crypto: digest = "0.10.3" @@ -82,18 +79,19 @@ ucd = "0.1.1" # compression adler32 = "1.2.0" crc32fast = "1.3.2" -flate2 = "1.0.28" +flate2 = { version = "1.1", default-features = false, features = ["zlib-rs"] } +libz-sys = { package = "libz-rs-sys", version = "0.4" } bzip2 = { version = "0.4", optional = true } # uuid [target.'cfg(not(any(target_os = "ios", target_os = "android", target_os = "windows", target_arch = "wasm32", target_os = "redox")))'.dependencies] mac_address = "1.1.3" -uuid = { version = "1.1.2", features = ["v1", "fast-rng"] } +uuid = { version = "1.1.2", features = ["v1"] } # mmap [target.'cfg(all(unix, not(target_arch = 
"wasm32")))'.dependencies] -memmap2 = "0.5.4" -page_size = "0.4" +memmap2 = "0.5.10" +page_size = "0.6" [target.'cfg(all(unix, not(target_os = "redox"), not(target_os = "ios")))'.dependencies] termios = "0.3.3" @@ -102,8 +100,8 @@ termios = "0.3.3" rustix = { workspace = true } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] -gethostname = "0.2.3" -socket2 = { version = "0.5.6", features = ["all"] } +gethostname = "1.0.0" +socket2 = { version = "0.5.8", features = ["all"] } dns-lookup = "2" openssl = { version = "0.10.66", optional = true } openssl-sys = { version = "0.9.80", optional = true } @@ -136,7 +134,7 @@ features = [ ] [target.'cfg(target_os = "macos")'.dependencies] -system-configuration = "0.5.0" +system-configuration = "0.5.1" [lints] workspace = true diff --git a/stdlib/build.rs b/stdlib/build.rs index 0109c633d7..3eb8a2d6b6 100644 --- a/stdlib/build.rs +++ b/stdlib/build.rs @@ -1,5 +1,6 @@ fn main() { println!(r#"cargo::rustc-check-cfg=cfg(osslconf, values("OPENSSL_NO_COMP"))"#); + println!(r#"cargo::rustc-check-cfg=cfg(openssl_vendored)"#); #[allow(clippy::unusual_byte_groupings)] let ossl_vers = [ @@ -36,4 +37,9 @@ fn main() { println!("cargo:rustc-cfg=osslconf=\"{conf}\""); } } + // it's possible for openssl-sys to link against the system openssl under certain conditions, + // so let the ssl module know to only perform a probe if we're actually vendored + if std::env::var("DEP_OPENSSL_VENDORED").is_ok_and(|s| s == "1") { + println!("cargo::rustc-cfg=openssl_vendored") + } } diff --git a/stdlib/src/array.rs b/stdlib/src/array.rs index 9bd58f8043..494e98dc1b 100644 --- a/stdlib/src/array.rs +++ b/stdlib/src/array.rs @@ -1,6 +1,6 @@ // spell-checker:ignore typecode tofile tolist fromfile -use rustpython_vm::{builtins::PyModule, PyRef, VirtualMachine}; +use rustpython_vm::{PyRef, VirtualMachine, builtins::PyModule}; pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { let module = array::make_module(vm); @@ -41,6 +41,7 @@ mod array { 
str::wchar_t, }, vm::{ + AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, atomic_func, builtins::{ PositionIterInternal, PyByteArray, PyBytes, PyBytesRef, PyDictRef, PyFloat, PyInt, @@ -64,7 +65,6 @@ mod array { AsBuffer, AsMapping, AsSequence, Comparable, Constructor, IterNext, Iterable, PyComparisonOp, Representable, SelfIter, }, - AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }, }; use itertools::Itertools; @@ -1528,7 +1528,7 @@ mod array { 2 => Some(Self::Utf16 { big_endian }), 4 => Some(Self::Utf32 { big_endian }), _ => None, - } + }; } 'f' => { // Copied from CPython diff --git a/stdlib/src/binascii.rs b/stdlib/src/binascii.rs index d348049f4c..ce3f9febd5 100644 --- a/stdlib/src/binascii.rs +++ b/stdlib/src/binascii.rs @@ -2,7 +2,7 @@ pub(super) use decl::crc32; pub(crate) use decl::make_module; -use rustpython_vm::{builtins::PyBaseExceptionRef, convert::ToPyException, VirtualMachine}; +use rustpython_vm::{VirtualMachine, builtins::PyBaseExceptionRef, convert::ToPyException}; const PAD: u8 = 61u8; const MAXLINESIZE: usize = 76; // Excluding the CRLF @@ -11,10 +11,10 @@ const MAXLINESIZE: usize = 76; // Excluding the CRLF mod decl { use super::{MAXLINESIZE, PAD}; use crate::vm::{ + PyResult, VirtualMachine, builtins::{PyIntRef, PyTypeRef}, convert::ToPyException, function::{ArgAsciiBuffer, ArgBytesLike, OptionalArg}, - PyResult, VirtualMachine, }; use itertools::Itertools; @@ -76,7 +76,7 @@ mod decl { let mut unhex = Vec::::with_capacity(hex_bytes.len() / 2); for (n1, n2) in hex_bytes.iter().tuples() { if let (Some(n1), Some(n2)) = (unhex_nibble(*n1), unhex_nibble(*n2)) { - unhex.push(n1 << 4 | n2); + unhex.push((n1 << 4) | n2); } else { return Err(super::new_binascii_error( "Non-hexadecimal digit found".to_owned(), @@ -343,7 +343,7 @@ mod decl { if let (Some(ch1), Some(ch2)) = (unhex_nibble(buffer[idx]), unhex_nibble(buffer[idx + 1])) { - out_data.push(ch1 << 4 | ch2); + out_data.push((ch1 << 
4) | ch2); } idx += 2; } else { @@ -661,19 +661,19 @@ mod decl { }; if res.len() < length { - res.push(char_a << 2 | char_b >> 4); + res.push((char_a << 2) | (char_b >> 4)); } else if char_a != 0 || char_b != 0 { return trailing_garbage_error(); } if res.len() < length { - res.push((char_b & 0xf) << 4 | char_c >> 2); + res.push(((char_b & 0xf) << 4) | (char_c >> 2)); } else if char_c != 0 { return trailing_garbage_error(); } if res.len() < length { - res.push((char_c & 0x3) << 6 | char_d); + res.push(((char_c & 0x3) << 6) | char_d); } else if char_d != 0 { return trailing_garbage_error(); } @@ -725,8 +725,8 @@ mod decl { let char_c = *chunk.get(2).unwrap_or(&0); res.push(uu_b2a(char_a >> 2, backtick)); - res.push(uu_b2a((char_a & 0x3) << 4 | char_b >> 4, backtick)); - res.push(uu_b2a((char_b & 0xf) << 2 | char_c >> 6, backtick)); + res.push(uu_b2a(((char_a & 0x3) << 4) | (char_b >> 4), backtick)); + res.push(uu_b2a(((char_b & 0xf) << 2) | (char_c >> 6), backtick)); res.push(uu_b2a(char_c & 0x3f, backtick)); } @@ -751,7 +751,10 @@ impl ToPyException for Base64DecodeError { InvalidByte(_, _) => "Only base64 data is allowed".to_owned(), InvalidLastSymbol(_, PAD) => "Excess data after padding".to_owned(), InvalidLastSymbol(length, _) => { - format!("Invalid base64-encoded string: number of data characters {} cannot be 1 more than a multiple of 4", length) + format!( + "Invalid base64-encoded string: number of data characters {} cannot be 1 more than a multiple of 4", + length + ) } InvalidLength => "Incorrect padding".to_owned(), }; diff --git a/stdlib/src/bisect.rs b/stdlib/src/bisect.rs index aaab65d788..4d67ee50b9 100644 --- a/stdlib/src/bisect.rs +++ b/stdlib/src/bisect.rs @@ -3,9 +3,9 @@ pub(crate) use _bisect::make_module; #[pymodule] mod _bisect { use crate::vm::{ + PyObjectRef, PyResult, VirtualMachine, function::{ArgIndex, OptionalArg}, types::PyComparisonOp, - PyObjectRef, PyResult, VirtualMachine, }; #[derive(FromArgs)] diff --git a/stdlib/src/blake2.rs 
b/stdlib/src/blake2.rs index 9b7da3327c..4209c966e8 100644 --- a/stdlib/src/blake2.rs +++ b/stdlib/src/blake2.rs @@ -4,7 +4,7 @@ pub(crate) use _blake2::make_module; #[pymodule] mod _blake2 { - use crate::hashlib::_hashlib::{local_blake2b, local_blake2s, BlakeHashArgs}; + use crate::hashlib::_hashlib::{BlakeHashArgs, local_blake2b, local_blake2s}; use crate::vm::{PyPayload, PyResult, VirtualMachine}; #[pyfunction] diff --git a/stdlib/src/bz2.rs b/stdlib/src/bz2.rs index f150b06eb8..ba74a38db1 100644 --- a/stdlib/src/bz2.rs +++ b/stdlib/src/bz2.rs @@ -6,13 +6,13 @@ pub(crate) use _bz2::make_module; mod _bz2 { use crate::common::lock::PyMutex; use crate::vm::{ + VirtualMachine, builtins::{PyBytesRef, PyTypeRef}, function::{ArgBytesLike, OptionalArg}, object::{PyPayload, PyResult}, types::Constructor, - VirtualMachine, }; - use bzip2::{write::BzEncoder, Decompress, Status}; + use bzip2::{Decompress, Status, write::BzEncoder}; use std::{fmt, io::Write}; // const BUFSIZ: i32 = 8192; @@ -196,7 +196,7 @@ mod _bz2 { _ => { return Err( vm.new_value_error("compresslevel must be between 1 and 9".to_owned()) - ) + ); } }; diff --git a/stdlib/src/cmath.rs b/stdlib/src/cmath.rs index c5badcf72a..4611ea344e 100644 --- a/stdlib/src/cmath.rs +++ b/stdlib/src/cmath.rs @@ -4,8 +4,8 @@ pub(crate) use cmath::make_module; #[pymodule] mod cmath { use crate::vm::{ - function::{ArgIntoComplex, ArgIntoFloat, OptionalArg}, PyResult, VirtualMachine, + function::{ArgIntoComplex, ArgIntoFloat, OptionalArg}, }; use num_complex::Complex64; diff --git a/stdlib/src/contextvars.rs b/stdlib/src/contextvars.rs index 40a59050b3..1e27b8b9e5 100644 --- a/stdlib/src/contextvars.rs +++ b/stdlib/src/contextvars.rs @@ -1,4 +1,4 @@ -use crate::vm::{builtins::PyModule, class::StaticType, PyRef, VirtualMachine}; +use crate::vm::{PyRef, VirtualMachine, builtins::PyModule, class::StaticType}; use _contextvars::PyContext; use std::cell::RefCell; @@ -23,14 +23,13 @@ thread_local! 
{ #[pymodule] mod _contextvars { use crate::vm::{ - atomic_func, + AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, atomic_func, builtins::{PyStrRef, PyTypeRef}, class::StaticType, common::hash::PyHash, function::{ArgCallable, FuncArgs, OptionalArg}, protocol::{PyMappingMethods, PySequenceMethods}, types::{AsMapping, AsSequence, Constructor, Hashable, Representable}, - AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; use indexmap::IndexMap; diff --git a/stdlib/src/csv.rs b/stdlib/src/csv.rs index 2dd7f2ae12..03a5429ba4 100644 --- a/stdlib/src/csv.rs +++ b/stdlib/src/csv.rs @@ -4,11 +4,11 @@ pub(crate) use _csv::make_module; mod _csv { use crate::common::lock::PyMutex; use crate::vm::{ + AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, builtins::{PyBaseExceptionRef, PyInt, PyNone, PyStr, PyType, PyTypeError, PyTypeRef}, function::{ArgIterable, ArgumentError, FromArgs, FuncArgs, OptionalArg}, protocol::{PyIter, PyIterReturn}, types::{Constructor, IterNext, Iterable, SelfIter}, - AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, }; use csv_core::Terminator; use itertools::{self, Itertools}; @@ -272,12 +272,12 @@ mod _csv { let Some(name) = name.payload_if_subclass::(vm) else { return Err(vm.new_type_error("argument 0 must be a string".to_string())); }; - let mut dialect = match dialect { + let dialect = match dialect { OptionalArg::Present(d) => PyDialect::try_from_object(vm, d) .map_err(|_| vm.new_type_error("argument 1 must be a dialect object".to_owned()))?, OptionalArg::Missing => opts.result(vm)?, }; - opts.update_pydialect(&mut dialect); + let dialect = opts.update_pydialect(dialect); GLOBAL_HASHMAP .lock() .insert(name.as_str().to_owned(), dialect); @@ -396,7 +396,7 @@ mod _csv { Some(write_meth) => write_meth, None if file.is_callable() => file, None => { - return Err(vm.new_type_error("argument 1 must 
have a \"write\" method".to_owned())) + return Err(vm.new_type_error("argument 1 must have a \"write\" method".to_owned())); } }; @@ -665,7 +665,7 @@ mod _csv { } impl FormatOptions { - fn update_pydialect<'b>(&self, res: &'b mut PyDialect) -> &'b mut PyDialect { + fn update_pydialect(&self, mut res: PyDialect) -> PyDialect { macro_rules! check_and_fill { ($res:ident, $e:ident) => {{ if let Some(t) = self.$e { @@ -699,24 +699,18 @@ mod _csv { DialectItem::Str(name) => { let g = GLOBAL_HASHMAP.lock(); if let Some(dialect) = g.get(name) { - let mut dialect = *dialect; - self.update_pydialect(&mut dialect); - Ok(dialect) + Ok(self.update_pydialect(*dialect)) } else { Err(new_csv_error(vm, format!("{} is not registed.", name))) } // TODO // Maybe need to update the obj from HashMap } - DialectItem::Obj(mut o) => { - self.update_pydialect(&mut o); - Ok(o) - } + DialectItem::Obj(o) => Ok(self.update_pydialect(*o)), DialectItem::None => { let g = GLOBAL_HASHMAP.lock(); - let mut res = *g.get("excel").unwrap(); - self.update_pydialect(&mut res); - Ok(res) + let res = *g.get("excel").unwrap(); + Ok(self.update_pydialect(res)) } } } @@ -1001,7 +995,7 @@ mod _csv { csv_core::ReadRecordResult::OutputEndsFull => resize_buf(output_ends), csv_core::ReadRecordResult::Record => break, csv_core::ReadRecordResult::End => { - return Ok(PyIterReturn::StopIteration(None)) + return Ok(PyIterReturn::StopIteration(None)); } } } diff --git a/stdlib/src/dis.rs b/stdlib/src/dis.rs index 12c2ea75df..69767ffbba 100644 --- a/stdlib/src/dis.rs +++ b/stdlib/src/dis.rs @@ -3,9 +3,9 @@ pub(crate) use decl::make_module; #[pymodule(name = "dis")] mod decl { use crate::vm::{ + PyObjectRef, PyRef, PyResult, TryFromObject, VirtualMachine, builtins::{PyCode, PyDictRef, PyStrRef}, bytecode::CodeFlags, - PyObjectRef, PyRef, PyResult, TryFromObject, VirtualMachine, }; #[pyfunction] diff --git a/stdlib/src/faulthandler.rs b/stdlib/src/faulthandler.rs index 2fb93ecc88..9ffd931291 100644 --- 
a/stdlib/src/faulthandler.rs +++ b/stdlib/src/faulthandler.rs @@ -2,7 +2,7 @@ pub(crate) use decl::make_module; #[pymodule(name = "faulthandler")] mod decl { - use crate::vm::{frame::Frame, function::OptionalArg, stdlib::sys::PyStderr, VirtualMachine}; + use crate::vm::{VirtualMachine, frame::Frame, function::OptionalArg, stdlib::sys::PyStderr}; fn dump_frame(frame: &Frame, vm: &VirtualMachine) { let stderr = PyStderr(vm); diff --git a/stdlib/src/fcntl.rs b/stdlib/src/fcntl.rs index ee73e50397..307d6e4351 100644 --- a/stdlib/src/fcntl.rs +++ b/stdlib/src/fcntl.rs @@ -3,10 +3,10 @@ pub(crate) use fcntl::make_module; #[pymodule] mod fcntl { use crate::vm::{ + PyResult, VirtualMachine, builtins::PyIntRef, function::{ArgMemoryBuffer, ArgStrOrBytesLike, Either, OptionalArg}, stdlib::{io, os}, - PyResult, VirtualMachine, }; // TODO: supply these from (please file an issue/PR upstream): @@ -20,7 +20,7 @@ mod fcntl { // I_LINK, I_UNLINK, I_PLINK, I_PUNLINK #[pyattr] - use libc::{FD_CLOEXEC, F_GETFD, F_GETFL, F_SETFD, F_SETFL}; + use libc::{F_GETFD, F_GETFL, F_SETFD, F_SETFL, FD_CLOEXEC}; #[cfg(not(target_os = "wasi"))] #[pyattr] @@ -45,7 +45,7 @@ mod fcntl { #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] #[pyattr] use libc::{ - F_ADD_SEALS, F_GETLEASE, F_GETPIPE_SZ, F_GET_SEALS, F_NOTIFY, F_SEAL_GROW, F_SEAL_SEAL, + F_ADD_SEALS, F_GET_SEALS, F_GETLEASE, F_GETPIPE_SZ, F_NOTIFY, F_SEAL_GROW, F_SEAL_SEAL, F_SEAL_SHRINK, F_SEAL_WRITE, F_SETLEASE, F_SETPIPE_SZ, }; diff --git a/stdlib/src/gc.rs b/stdlib/src/gc.rs index c78eea9c29..6e906ebab2 100644 --- a/stdlib/src/gc.rs +++ b/stdlib/src/gc.rs @@ -2,7 +2,7 @@ pub(crate) use gc::make_module; #[pymodule] mod gc { - use crate::vm::{function::FuncArgs, PyResult, VirtualMachine}; + use crate::vm::{PyResult, VirtualMachine, function::FuncArgs}; #[pyfunction] fn collect(_args: FuncArgs, _vm: &VirtualMachine) -> i32 { diff --git a/stdlib/src/grp.rs b/stdlib/src/grp.rs index d3eb0848bb..2cdad56588 100644 
--- a/stdlib/src/grp.rs +++ b/stdlib/src/grp.rs @@ -3,11 +3,11 @@ pub(crate) use grp::make_module; #[pymodule] mod grp { use crate::vm::{ + PyObjectRef, PyResult, VirtualMachine, builtins::{PyIntRef, PyListRef, PyStrRef}, convert::{IntoPyException, ToPyObject}, exceptions, types::PyStructSequence, - PyObjectRef, PyResult, VirtualMachine, }; use nix::unistd; use std::ptr::NonNull; diff --git a/stdlib/src/hashlib.rs b/stdlib/src/hashlib.rs index 6944c37f9d..6124b6d242 100644 --- a/stdlib/src/hashlib.rs +++ b/stdlib/src/hashlib.rs @@ -6,16 +6,16 @@ pub(crate) use _hashlib::make_module; pub mod _hashlib { use crate::common::lock::PyRwLock; use crate::vm::{ + PyObjectRef, PyPayload, PyResult, VirtualMachine, builtins::{PyBytes, PyStrRef, PyTypeRef}, convert::ToPyObject, function::{ArgBytesLike, ArgStrOrBytesLike, FuncArgs, OptionalArg}, protocol::PyBuffer, - PyObjectRef, PyPayload, PyResult, VirtualMachine, }; use blake2::{Blake2b512, Blake2s256}; - use digest::{core_api::BlockSizeUser, DynDigest}; + use digest::{DynDigest, core_api::BlockSizeUser}; use digest::{ExtendableOutput, Update}; - use dyn_clone::{clone_trait_object, DynClone}; + use dyn_clone::{DynClone, clone_trait_object}; use md5::Md5; use sha1::Sha1; use sha2::{Sha224, Sha256, Sha384, Sha512}; diff --git a/stdlib/src/json.rs b/stdlib/src/json.rs index 921e545e5d..aaac0b8bef 100644 --- a/stdlib/src/json.rs +++ b/stdlib/src/json.rs @@ -5,12 +5,12 @@ mod machinery; mod _json { use super::machinery; use crate::vm::{ + AsObject, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, builtins::{PyBaseExceptionRef, PyStrRef, PyType, PyTypeRef}, convert::{ToPyObject, ToPyResult}, function::{IntoFuncArgs, OptionalArg}, protocol::PyIterReturn, types::{Callable, Constructor}, - AsObject, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, }; use malachite_bigint::BigInt; use std::str::FromStr; @@ -80,14 +80,14 @@ mod _json { None => { return Ok(PyIterReturn::StopIteration(Some( vm.ctx.new_int(idx).into(), - ))) + 
))); } }; let next_idx = idx + c.len_utf8(); match c { '"' => { return scanstring(pystr, next_idx, OptionalArg::Present(self.strict), vm) - .map(|x| PyIterReturn::Return(x.to_pyobject(vm))) + .map(|x| PyIterReturn::Return(x.to_pyobject(vm))); } '{' => { // TODO: parse the object in rust diff --git a/stdlib/src/json/machinery.rs b/stdlib/src/json/machinery.rs index fc6b530866..0614314f4f 100644 --- a/stdlib/src/json/machinery.rs +++ b/stdlib/src/json/machinery.rs @@ -135,7 +135,7 @@ pub fn scanstring<'a>( }; let unterminated_err = || DecodeError::new("Unterminated string starting at", end - 1); let mut chars = s.char_indices().enumerate().skip(end).peekable(); - let (_, (mut chunk_start, _)) = chars.peek().ok_or_else(unterminated_err)?; + let &(_, (mut chunk_start, _)) = chars.peek().ok_or_else(unterminated_err)?; while let Some((char_i, (byte_i, c))) = chars.next() { match c { '"' => { diff --git a/stdlib/src/locale.rs b/stdlib/src/locale.rs index bbe1008e53..9ca71a0957 100644 --- a/stdlib/src/locale.rs +++ b/stdlib/src/locale.rs @@ -30,7 +30,7 @@ struct lconv { } #[cfg(windows)] -extern "C" { +unsafe extern "C" { fn localeconv() -> *mut lconv; } @@ -40,10 +40,10 @@ use libc::localeconv; #[pymodule] mod _locale { use rustpython_vm::{ + PyObjectRef, PyResult, VirtualMachine, builtins::{PyDictRef, PyIntRef, PyListRef, PyStrRef, PyTypeRef}, convert::ToPyException, function::OptionalArg, - PyObjectRef, PyResult, VirtualMachine, }; use std::{ ffi::{CStr, CString}, @@ -56,12 +56,12 @@ mod _locale { ))] #[pyattr] use libc::{ - ABDAY_1, ABDAY_2, ABDAY_3, ABDAY_4, ABDAY_5, ABDAY_6, ABDAY_7, ABMON_1, ABMON_10, ABMON_11, - ABMON_12, ABMON_2, ABMON_3, ABMON_4, ABMON_5, ABMON_6, ABMON_7, ABMON_8, ABMON_9, - ALT_DIGITS, AM_STR, CODESET, CRNCYSTR, DAY_1, DAY_2, DAY_3, DAY_4, DAY_5, DAY_6, DAY_7, - D_FMT, D_T_FMT, ERA, ERA_D_FMT, ERA_D_T_FMT, ERA_T_FMT, LC_MESSAGES, MON_1, MON_10, MON_11, - MON_12, MON_2, MON_3, MON_4, MON_5, MON_6, MON_7, MON_8, MON_9, NOEXPR, PM_STR, RADIXCHAR, 
- THOUSEP, T_FMT, T_FMT_AMPM, YESEXPR, + ABDAY_1, ABDAY_2, ABDAY_3, ABDAY_4, ABDAY_5, ABDAY_6, ABDAY_7, ABMON_1, ABMON_2, ABMON_3, + ABMON_4, ABMON_5, ABMON_6, ABMON_7, ABMON_8, ABMON_9, ABMON_10, ABMON_11, ABMON_12, + ALT_DIGITS, AM_STR, CODESET, CRNCYSTR, D_FMT, D_T_FMT, DAY_1, DAY_2, DAY_3, DAY_4, DAY_5, + DAY_6, DAY_7, ERA, ERA_D_FMT, ERA_D_T_FMT, ERA_T_FMT, LC_MESSAGES, MON_1, MON_2, MON_3, + MON_4, MON_5, MON_6, MON_7, MON_8, MON_9, MON_10, MON_11, MON_12, NOEXPR, PM_STR, + RADIXCHAR, T_FMT, T_FMT_AMPM, THOUSEP, YESEXPR, }; #[pyattr] @@ -78,11 +78,13 @@ mod _locale { return vm.ctx.new_list(group_vec); } - let mut ptr = group; - while ![0, libc::c_char::MAX].contains(&*ptr) { - let val = vm.ctx.new_int(*ptr); - group_vec.push(val.into()); - ptr = ptr.add(1); + unsafe { + let mut ptr = group; + while ![0, libc::c_char::MAX].contains(&*ptr) { + let val = vm.ctx.new_int(*ptr); + group_vec.push(val.into()); + ptr = ptr.add(1); + } } // https://github.com/python/cpython/blob/677320348728ce058fa3579017e985af74a236d4/Modules/_localemodule.c#L80 if !group_vec.is_empty() { @@ -146,9 +148,7 @@ mod _locale { } macro_rules! set_int_field { - ($lc:expr, $field:ident) => {{ - result.set_item(stringify!($field), vm.new_pyobj((*$lc).$field), vm)? - }}; + ($lc:expr, $field:ident) => {{ result.set_item(stringify!($field), vm.new_pyobj((*$lc).$field), vm)? }}; } macro_rules! 
set_group_field { diff --git a/stdlib/src/math.rs b/stdlib/src/math.rs index c1abc1a6f2..f86ebb591e 100644 --- a/stdlib/src/math.rs +++ b/stdlib/src/math.rs @@ -3,9 +3,10 @@ pub(crate) use math::make_module; #[pymodule] mod math { use crate::vm::{ - builtins::{try_bigint_to_f64, try_f64_to_bigint, PyFloat, PyInt, PyIntRef, PyStrInterned}, + PyObject, PyObjectRef, PyRef, PyResult, VirtualMachine, + builtins::{PyFloat, PyInt, PyIntRef, PyStrInterned, try_bigint_to_f64, try_f64_to_bigint}, function::{ArgIndex, ArgIntoFloat, ArgIterable, Either, OptionalArg, PosArgs}, - identifier, PyObject, PyObjectRef, PyRef, PyResult, VirtualMachine, + identifier, }; use itertools::Itertools; use malachite_bigint::BigInt; @@ -132,6 +133,9 @@ mod math { #[pyfunction] fn log(x: PyObjectRef, base: OptionalArg, vm: &VirtualMachine) -> PyResult { let base = base.map(|b| *b).unwrap_or(std::f64::consts::E); + if base.is_sign_negative() { + return Err(vm.new_value_error("math domain error".to_owned())); + } log2(x, vm).map(|logx| logx / base.log2()) } @@ -192,16 +196,18 @@ mod math { let x = *x; let y = *y; - if x < 0.0 && x.is_finite() && y.fract() != 0.0 && y.is_finite() { - return Err(vm.new_value_error("math domain error".to_owned())); - } - - if x == 0.0 && y < 0.0 && y != f64::NEG_INFINITY { + if x < 0.0 && x.is_finite() && y.fract() != 0.0 && y.is_finite() + || x == 0.0 && y < 0.0 && y != f64::NEG_INFINITY + { return Err(vm.new_value_error("math domain error".to_owned())); } let value = x.powf(y); + if x.is_finite() && y.is_finite() && value.is_infinite() { + return Err(vm.new_overflow_error("math range error".to_string())); + } + Ok(value) } @@ -212,6 +218,9 @@ mod math { return Ok(value); } if value.is_sign_negative() { + if value.is_zero() { + return Ok(-0.0f64); + } return Err(vm.new_value_error("math domain error".to_owned())); } Ok(value.sqrt()) @@ -260,6 +269,9 @@ mod math { #[pyfunction] fn cos(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { + if x.is_infinite() { + 
return Err(vm.new_value_error("math domain error".to_owned())); + } call_math_func!(cos, x, vm) } @@ -345,7 +357,7 @@ mod math { .map(|x| (x / scale).powi(2)) .chain(std::iter::once(-norm * norm)) // Pairwise summation of floats gives less rounding error than a naive sum. - .tree_fold1(std::ops::Add::add) + .tree_reduce(std::ops::Add::add) .expect("expected at least 1 element"); norm = norm + correction / (2.0 * norm); } @@ -394,11 +406,17 @@ mod math { #[pyfunction] fn sin(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { + if x.is_infinite() { + return Err(vm.new_value_error("math domain error".to_owned())); + } call_math_func!(sin, x, vm) } #[pyfunction] fn tan(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { + if x.is_infinite() { + return Err(vm.new_value_error("math domain error".to_owned())); + } call_math_func!(tan, x, vm) } @@ -458,21 +476,13 @@ mod math { #[pyfunction] fn erf(x: ArgIntoFloat) -> f64 { let x = *x; - if x.is_nan() { - x - } else { - puruspe::erf(x) - } + if x.is_nan() { x } else { puruspe::erf(x) } } #[pyfunction] fn erfc(x: ArgIntoFloat) -> f64 { let x = *x; - if x.is_nan() { - x - } else { - puruspe::erfc(x) - } + if x.is_nan() { x } else { puruspe::erfc(x) } } #[pyfunction] @@ -612,7 +622,7 @@ mod math { #[pyfunction] fn fsum(seq: ArgIterable, vm: &VirtualMachine) -> PyResult { - let mut partials = vec![]; + let mut partials = Vec::with_capacity(32); let mut special_sum = 0.0; let mut inf_sum = 0.0; @@ -620,11 +630,11 @@ mod math { let mut x = *obj?; let xsave = x; - let mut j = 0; + let mut i = 0; // This inner loop applies `hi`/`lo` summation to each // partial so that the list of partial sums remains exact. 
- for i in 0..partials.len() { - let mut y: f64 = partials[i]; + for j in 0..partials.len() { + let mut y: f64 = partials[j]; if x.abs() < y.abs() { std::mem::swap(&mut x, &mut y); } @@ -633,33 +643,33 @@ mod math { let hi = x + y; let lo = y - (hi - x); if lo != 0.0 { - partials[j] = lo; - j += 1; + partials[i] = lo; + i += 1; } x = hi; } - if !x.is_finite() { - // a nonfinite x could arise either as - // a result of intermediate overflow, or - // as a result of a nan or inf in the - // summands - if xsave.is_finite() { - return Err(vm.new_overflow_error("intermediate overflow in fsum".to_owned())); - } - if xsave.is_infinite() { - inf_sum += xsave; + partials.truncate(i); + if x != 0.0 { + if !x.is_finite() { + // a nonfinite x could arise either as + // a result of intermediate overflow, or + // as a result of a nan or inf in the + // summands + if xsave.is_finite() { + return Err( + vm.new_overflow_error("intermediate overflow in fsum".to_owned()) + ); + } + if xsave.is_infinite() { + inf_sum += xsave; + } + special_sum += xsave; + // reset partials + partials.clear(); + } else { + partials.push(x); } - special_sum += xsave; - // reset partials - partials.clear(); - } - - if j >= partials.len() { - partials.push(x); - } else { - partials[j] = x; - partials.truncate(j + 1); } } if special_sum != 0.0 { @@ -814,9 +824,38 @@ mod math { (x.fract(), x.trunc()) } - #[pyfunction] - fn nextafter(x: ArgIntoFloat, y: ArgIntoFloat) -> f64 { - float_ops::nextafter(*x, *y) + #[derive(FromArgs)] + struct NextAfterArgs { + #[pyarg(positional)] + x: ArgIntoFloat, + #[pyarg(positional)] + y: ArgIntoFloat, + #[pyarg(named, optional)] + steps: OptionalArg, + } + + #[pyfunction] + fn nextafter(arg: NextAfterArgs, vm: &VirtualMachine) -> PyResult { + let steps: Option = arg + .steps + .map(|v| v.try_to_primitive(vm)) + .transpose()? 
+ .into_option(); + match steps { + Some(steps) => { + if steps < 0 { + return Err( + vm.new_value_error("steps must be a non-negative integer".to_string()) + ); + } + Ok(float_ops::nextafter_with_steps( + *arg.x, + *arg.y, + steps as u64, + )) + } + None => Ok(float_ops::nextafter(*arg.x, *arg.y)), + } } #[pyfunction] @@ -900,10 +939,38 @@ mod math { // refer: https://github.com/python/cpython/blob/main/Modules/mathmodule.c#L3093-L3193 for obj in iter.iter(vm)? { let obj = obj?; + result = vm._mul(&result, &obj)?; + } + + Ok(result) + } - result = vm - ._mul(&result, &obj) - .map_err(|_| vm.new_type_error("math type error".to_owned()))?; + #[pyfunction] + fn sumprod( + p: ArgIterable, + q: ArgIterable, + vm: &VirtualMachine, + ) -> PyResult { + let mut p_iter = p.iter(vm)?; + let mut q_iter = q.iter(vm)?; + // We cannot just create a float because the iterator may contain + // anything as long as it supports __add__ and __mul__. + let mut result = vm.new_pyobj(0); + loop { + let m_p = p_iter.next(); + let m_q = q_iter.next(); + match (m_p, m_q) { + (Some(r_p), Some(r_q)) => { + let p = r_p?; + let q = r_q?; + let tmp = vm._mul(&p, &q)?; + result = vm._add(&result, &tmp)?; + } + (None, None) => break, + _ => { + return Err(vm.new_value_error("Inputs are not the same length".to_string())); + } + } } Ok(result) diff --git a/stdlib/src/md5.rs b/stdlib/src/md5.rs index 833d217f5b..dca48242bb 100644 --- a/stdlib/src/md5.rs +++ b/stdlib/src/md5.rs @@ -2,7 +2,7 @@ pub(crate) use _md5::make_module; #[pymodule] mod _md5 { - use crate::hashlib::_hashlib::{local_md5, HashArgs}; + use crate::hashlib::_hashlib::{HashArgs, local_md5}; use crate::vm::{PyPayload, PyResult, VirtualMachine}; #[pyfunction] diff --git a/stdlib/src/mmap.rs b/stdlib/src/mmap.rs index 8b657532c4..e96339c370 100644 --- a/stdlib/src/mmap.rs +++ b/stdlib/src/mmap.rs @@ -8,7 +8,8 @@ mod mmap { lock::{MapImmutable, PyMutex, PyMutexGuard}, }; use crate::vm::{ - atomic_func, + AsObject, FromArgs, Py, PyObject, 
PyObjectRef, PyPayload, PyRef, PyResult, + TryFromBorrowedObject, VirtualMachine, atomic_func, builtins::{PyBytes, PyBytesRef, PyInt, PyIntRef, PyTypeRef}, byte::{bytes_from_object, value_from_object}, convert::ToPyException, @@ -18,8 +19,6 @@ mod mmap { }, sliceable::{SaturatedSlice, SequenceIndex, SequenceIndexOp}, types::{AsBuffer, AsMapping, AsSequence, Constructor, Representable}, - AsObject, FromArgs, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, - TryFromBorrowedObject, VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; use memmap2::{Advice, Mmap, MmapMut, MmapOptions}; @@ -497,8 +496,8 @@ mod mmap { fn as_bytes(&self) -> BorrowedValue<[u8]> { PyMutexGuard::map_immutable(self.mmap.lock(), |m| { match m.as_ref().expect("mmap closed or invalid") { - MmapObj::Read(ref mmap) => &mmap[..], - MmapObj::Write(ref mmap) => &mmap[..], + MmapObj::Read(mmap) => &mmap[..], + MmapObj::Write(mmap) => &mmap[..], } }) .into() diff --git a/stdlib/src/multiprocessing.rs b/stdlib/src/multiprocessing.rs index a6d902eb63..2db922e16b 100644 --- a/stdlib/src/multiprocessing.rs +++ b/stdlib/src/multiprocessing.rs @@ -3,7 +3,7 @@ pub(crate) use _multiprocessing::make_module; #[cfg(windows)] #[pymodule] mod _multiprocessing { - use crate::vm::{function::ArgBytesLike, stdlib::os, PyResult, VirtualMachine}; + use crate::vm::{PyResult, VirtualMachine, function::ArgBytesLike, stdlib::os}; use windows_sys::Win32::Networking::WinSock::{self, SOCKET}; #[pyfunction] diff --git a/stdlib/src/overlapped.rs b/stdlib/src/overlapped.rs index 9d08d88bcd..45eac5f51b 100644 --- a/stdlib/src/overlapped.rs +++ b/stdlib/src/overlapped.rs @@ -6,13 +6,13 @@ mod _overlapped { // straight-forward port of Modules/overlapped.c use crate::vm::{ + Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, builtins::{PyBaseExceptionRef, PyBytesRef, PyTypeRef}, common::lock::PyMutex, convert::{ToPyException, ToPyObject}, protocol::PyBuffer, stdlib::os::errno_err, types::Constructor, - Py, PyObjectRef, 
PyPayload, PyResult, VirtualMachine, }; use windows_sys::Win32::{ Foundation::{self, GetLastError, HANDLE}, diff --git a/stdlib/src/posixsubprocess.rs b/stdlib/src/posixsubprocess.rs index 64e897b607..cff00a70aa 100644 --- a/stdlib/src/posixsubprocess.rs +++ b/stdlib/src/posixsubprocess.rs @@ -28,7 +28,7 @@ mod _posixsubprocess { use rustpython_vm::{AsObject, TryFromBorrowedObject}; use super::*; - use crate::vm::{convert::IntoPyException, PyResult, VirtualMachine}; + use crate::vm::{PyResult, VirtualMachine, convert::IntoPyException}; #[pyfunction] fn fork_exec(args: ForkExecArgs, vm: &VirtualMachine) -> PyResult { diff --git a/stdlib/src/pyexpat.rs b/stdlib/src/pyexpat.rs index 89267d3f7e..3cfe048f17 100644 --- a/stdlib/src/pyexpat.rs +++ b/stdlib/src/pyexpat.rs @@ -3,7 +3,7 @@ * */ -use crate::vm::{builtins::PyModule, extend_module, PyRef, VirtualMachine}; +use crate::vm::{PyRef, VirtualMachine, builtins::PyModule, extend_module}; pub fn make_module(vm: &VirtualMachine) -> PyRef { let module = _pyexpat::make_module(vm); @@ -32,10 +32,10 @@ macro_rules! 
create_property { #[pymodule(name = "pyexpat")] mod _pyexpat { use crate::vm::{ + Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, builtins::{PyStr, PyStrRef, PyType}, function::ArgBytesLike, function::{IntoFuncArgs, OptionalArg}, - Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, }; use rustpython_common::lock::PyRwLock; use std::io::Cursor; diff --git a/stdlib/src/pystruct.rs b/stdlib/src/pystruct.rs index 2d83e9570d..f8d41414f7 100644 --- a/stdlib/src/pystruct.rs +++ b/stdlib/src/pystruct.rs @@ -10,13 +10,13 @@ pub(crate) use _struct::make_module; #[pymodule] pub(crate) mod _struct { use crate::vm::{ - buffer::{new_struct_error, struct_error_type, FormatSpec}, + AsObject, Py, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine, + buffer::{FormatSpec, new_struct_error, struct_error_type}, builtins::{PyBytes, PyStr, PyStrRef, PyTupleRef, PyTypeRef}, function::{ArgBytesLike, ArgMemoryBuffer, PosArgs}, match_class, protocol::PyIterReturn, types::{Constructor, IterNext, Iterable, SelfIter}, - AsObject, Py, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; diff --git a/stdlib/src/random.rs b/stdlib/src/random.rs index 1dfc4fcc30..685c0ae8b9 100644 --- a/stdlib/src/random.rs +++ b/stdlib/src/random.rs @@ -6,79 +6,37 @@ pub(crate) use _random::make_module; mod _random { use crate::common::lock::PyMutex; use crate::vm::{ - builtins::{PyInt, PyTypeRef}, + PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, + builtins::{PyInt, PyTupleRef}, + convert::ToPyException, function::OptionalOption, - types::Constructor, - PyObjectRef, PyPayload, PyResult, VirtualMachine, + types::{Constructor, Initializer}, }; + use itertools::Itertools; use malachite_bigint::{BigInt, BigUint, Sign}; + use mt19937::MT19937; use num_traits::{Signed, Zero}; - use rand::{rngs::StdRng, RngCore, SeedableRng}; - - #[derive(Debug)] - enum PyRng { - Std(Box), 
- MT(Box), - } - - impl Default for PyRng { - fn default() -> Self { - PyRng::Std(Box::new(StdRng::from_entropy())) - } - } - - impl RngCore for PyRng { - fn next_u32(&mut self) -> u32 { - match self { - Self::Std(s) => s.next_u32(), - Self::MT(m) => m.next_u32(), - } - } - fn next_u64(&mut self) -> u64 { - match self { - Self::Std(s) => s.next_u64(), - Self::MT(m) => m.next_u64(), - } - } - fn fill_bytes(&mut self, dest: &mut [u8]) { - match self { - Self::Std(s) => s.fill_bytes(dest), - Self::MT(m) => m.fill_bytes(dest), - } - } - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> { - match self { - Self::Std(s) => s.try_fill_bytes(dest), - Self::MT(m) => m.try_fill_bytes(dest), - } - } - } + use rand::{RngCore, SeedableRng}; + use rustpython_vm::types::DefaultConstructor; #[pyattr] #[pyclass(name = "Random")] - #[derive(Debug, PyPayload)] + #[derive(Debug, PyPayload, Default)] struct PyRandom { - rng: PyMutex, + rng: PyMutex, } - impl Constructor for PyRandom { - type Args = OptionalOption; + impl DefaultConstructor for PyRandom {} - fn py_new( - cls: PyTypeRef, - // TODO: use x as the seed. - _x: Self::Args, - vm: &VirtualMachine, - ) -> PyResult { - PyRandom { - rng: PyMutex::default(), - } - .into_ref_with_type(vm, cls) - .map(Into::into) + impl Initializer for PyRandom { + type Args = OptionalOption; + + fn init(zelf: PyRef, x: Self::Args, vm: &VirtualMachine) -> PyResult<()> { + zelf.seed(x, vm) } } - #[pyclass(flags(BASETYPE), with(Constructor))] + #[pyclass(flags(BASETYPE), with(Constructor, Initializer))] impl PyRandom { #[pymethod] fn random(&self) -> f64 { @@ -88,9 +46,8 @@ mod _random { #[pymethod] fn seed(&self, n: OptionalOption, vm: &VirtualMachine) -> PyResult<()> { - let new_rng = n - .flatten() - .map(|n| { + *self.rng.lock() = match n.flatten() { + Some(n) => { // Fallback to using hash if object isn't Int-like. 
let (_, mut key) = match n.downcast::() { Ok(n) => n.as_bigint().abs(), @@ -101,34 +58,24 @@ mod _random { key.reverse(); } let key = if key.is_empty() { &[0] } else { key.as_slice() }; - Ok(PyRng::MT(Box::new(mt19937::MT19937::new_with_slice_seed( - key, - )))) - }) - .transpose()? - .unwrap_or_default(); - - *self.rng.lock() = new_rng; + MT19937::new_with_slice_seed(key) + } + None => MT19937::try_from_os_rng() + .map_err(|e| std::io::Error::from(e).to_pyexception(vm))?, + }; Ok(()) } #[pymethod] fn getrandbits(&self, k: isize, vm: &VirtualMachine) -> PyResult { match k { - k if k < 0 => { - Err(vm.new_value_error("number of bits must be non-negative".to_owned())) - } + ..0 => Err(vm.new_value_error("number of bits must be non-negative".to_owned())), 0 => Ok(BigInt::zero()), - _ => { + mut k => { let mut rng = self.rng.lock(); - let mut k = k; let mut gen_u32 = |k| { let r = rng.next_u32(); - if k < 32 { - r >> (32 - k) - } else { - r - } + if k < 32 { r >> (32 - k) } else { r } }; let words = (k - 1) / 32 + 1; @@ -151,5 +98,40 @@ mod _random { } } } + + #[pymethod] + fn getstate(&self, vm: &VirtualMachine) -> PyTupleRef { + let rng = self.rng.lock(); + vm.new_tuple( + rng.get_state() + .iter() + .copied() + .chain([rng.get_index() as u32]) + .map(|i| vm.ctx.new_int(i).into()) + .collect::>(), + ) + } + + #[pymethod] + fn setstate(&self, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> { + let state: &[_; mt19937::N + 1] = state + .as_slice() + .try_into() + .map_err(|_| vm.new_value_error("state vector is the wrong size".to_owned()))?; + let (index, state) = state.split_last().unwrap(); + let index: usize = index.try_to_value(vm)?; + if index > mt19937::N { + return Err(vm.new_value_error("invalid state".to_owned())); + } + let state: [u32; mt19937::N] = state + .iter() + .map(|i| i.try_to_value(vm)) + .process_results(|it| it.collect_array())? 
+ .unwrap(); + let mut rng = self.rng.lock(); + rng.set_state(&state); + rng.set_index(index); + Ok(()) + } } } diff --git a/stdlib/src/resource.rs b/stdlib/src/resource.rs index c1b74f2b67..e103cce779 100644 --- a/stdlib/src/resource.rs +++ b/stdlib/src/resource.rs @@ -3,10 +3,10 @@ pub(crate) use resource::make_module; #[pymodule] mod resource { use crate::vm::{ + PyObject, PyObjectRef, PyResult, TryFromBorrowedObject, VirtualMachine, convert::{ToPyException, ToPyObject}, stdlib::os, types::PyStructSequence, - PyObject, PyObjectRef, PyResult, TryFromBorrowedObject, VirtualMachine, }; use std::{io, mem}; @@ -24,8 +24,8 @@ mod resource { // TODO: RLIMIT_OFILE, #[pyattr] use libc::{ - RLIMIT_AS, RLIMIT_CORE, RLIMIT_CPU, RLIMIT_DATA, RLIMIT_FSIZE, RLIMIT_MEMLOCK, - RLIMIT_NOFILE, RLIMIT_NPROC, RLIMIT_RSS, RLIMIT_STACK, RLIM_INFINITY, + RLIM_INFINITY, RLIMIT_AS, RLIMIT_CORE, RLIMIT_CPU, RLIMIT_DATA, RLIMIT_FSIZE, + RLIMIT_MEMLOCK, RLIMIT_NOFILE, RLIMIT_NPROC, RLIMIT_RSS, RLIMIT_STACK, }; #[cfg(any(target_os = "linux", target_os = "android", target_os = "emscripten"))] diff --git a/stdlib/src/scproxy.rs b/stdlib/src/scproxy.rs index 7108f50d8f..9bf29626ab 100644 --- a/stdlib/src/scproxy.rs +++ b/stdlib/src/scproxy.rs @@ -5,9 +5,9 @@ mod _scproxy { // straight-forward port of Modules/_scproxy.c use crate::vm::{ + PyResult, VirtualMachine, builtins::{PyDictRef, PyStr}, convert::ToPyObject, - PyResult, VirtualMachine, }; use system_configuration::core_foundation::{ array::CFArray, diff --git a/stdlib/src/select.rs b/stdlib/src/select.rs index af76d86c8a..4003856bb9 100644 --- a/stdlib/src/select.rs +++ b/stdlib/src/select.rs @@ -1,6 +1,6 @@ use crate::vm::{ - builtins::PyListRef, builtins::PyModule, PyObject, PyObjectRef, PyRef, PyResult, TryFromObject, - VirtualMachine, + PyObject, PyObjectRef, PyRef, PyResult, TryFromObject, VirtualMachine, builtins::PyListRef, + builtins::PyModule, }; use std::{io, mem}; @@ -19,7 +19,7 @@ pub(crate) fn make_module(vm: &VirtualMachine) 
-> PyRef { #[cfg(unix)] mod platform { - pub use libc::{fd_set, select, timeval, FD_ISSET, FD_SET, FD_SETSIZE, FD_ZERO}; + pub use libc::{FD_ISSET, FD_SET, FD_SETSIZE, FD_ZERO, fd_set, select, timeval}; pub use std::os::unix::io::RawFd; pub fn check_err(x: i32) -> bool { @@ -30,34 +30,36 @@ mod platform { #[allow(non_snake_case)] #[cfg(windows)] mod platform { + pub use WinSock::{FD_SET as fd_set, FD_SETSIZE, SOCKET as RawFd, TIMEVAL as timeval, select}; use windows_sys::Win32::Networking::WinSock; - pub use WinSock::{select, FD_SET as fd_set, FD_SETSIZE, SOCKET as RawFd, TIMEVAL as timeval}; // based off winsock2.h: https://gist.github.com/piscisaureus/906386#file-winsock2-h-L128-L141 pub unsafe fn FD_SET(fd: RawFd, set: *mut fd_set) { - let mut slot = (&raw mut (*set).fd_array).cast::(); - let fd_count = (*set).fd_count; - for _ in 0..fd_count { - if *slot == fd { - return; + unsafe { + let mut slot = (&raw mut (*set).fd_array).cast::(); + let fd_count = (*set).fd_count; + for _ in 0..fd_count { + if *slot == fd { + return; + } + slot = slot.add(1); + } + // slot == &fd_array[fd_count] at this point + if fd_count < FD_SETSIZE { + *slot = fd as RawFd; + (*set).fd_count += 1; } - slot = slot.add(1); - } - // slot == &fd_array[fd_count] at this point - if fd_count < FD_SETSIZE { - *slot = fd as RawFd; - (*set).fd_count += 1; } } pub unsafe fn FD_ZERO(set: *mut fd_set) { - (*set).fd_count = 0; + unsafe { (*set).fd_count = 0 }; } pub unsafe fn FD_ISSET(fd: RawFd, set: *mut fd_set) -> bool { use WinSock::__WSAFDIsSet; - __WSAFDIsSet(fd as _, set) != 0 + unsafe { __WSAFDIsSet(fd as _, set) != 0 } } pub fn check_err(x: i32) -> bool { @@ -67,7 +69,7 @@ mod platform { #[cfg(target_os = "wasi")] mod platform { - pub use libc::{timeval, FD_SETSIZE}; + pub use libc::{FD_SETSIZE, timeval}; pub use std::os::wasi::io::RawFd; pub fn check_err(x: i32) -> bool { @@ -82,7 +84,7 @@ mod platform { #[allow(non_snake_case)] pub unsafe fn FD_ISSET(fd: RawFd, set: *const fd_set) -> bool { 
- let set = &*set; + let set = unsafe { &*set }; let n = set.__nfds; for p in &set.__fds[..n] { if *p == fd { @@ -94,7 +96,7 @@ mod platform { #[allow(non_snake_case)] pub unsafe fn FD_SET(fd: RawFd, set: *mut fd_set) { - let set = &mut *set; + let set = unsafe { &mut *set }; let n = set.__nfds; for p in &set.__fds[..n] { if *p == fd { @@ -107,11 +109,11 @@ mod platform { #[allow(non_snake_case)] pub unsafe fn FD_ZERO(set: *mut fd_set) { - let set = &mut *set; + let set = unsafe { &mut *set }; set.__nfds = 0; } - extern "C" { + unsafe extern "C" { pub fn select( nfds: libc::c_int, readfds: *mut fd_set, @@ -122,8 +124,8 @@ mod platform { } } -pub use platform::timeval; use platform::RawFd; +pub use platform::timeval; #[derive(Traverse)] struct Selectable { @@ -216,11 +218,11 @@ fn sec_to_timeval(sec: f64) -> timeval { mod decl { use super::*; use crate::vm::{ + PyObjectRef, PyResult, VirtualMachine, builtins::PyTypeRef, convert::ToPyException, function::{Either, OptionalOption}, stdlib::time, - PyObjectRef, PyResult, VirtualMachine, }; #[pyattr] @@ -325,12 +327,12 @@ mod decl { pub(super) mod poll { use super::*; use crate::vm::{ + AsObject, PyPayload, builtins::PyFloat, common::lock::PyMutex, convert::{IntoPyException, ToPyObject}, function::OptionalArg, stdlib::io::Fildes, - AsObject, PyPayload, }; use libc::pollfd; use num_traits::{Signed, ToPrimitive}; @@ -492,8 +494,9 @@ mod decl { #[cfg(any(target_os = "linux", target_os = "android", target_os = "redox"))] #[pyattr] use libc::{ - EPOLLERR, EPOLLEXCLUSIVE, EPOLLHUP, EPOLLIN, EPOLLMSG, EPOLLONESHOT, EPOLLOUT, EPOLLPRI, - EPOLLRDBAND, EPOLLRDHUP, EPOLLRDNORM, EPOLLWAKEUP, EPOLLWRBAND, EPOLLWRNORM, EPOLL_CLOEXEC, + EPOLL_CLOEXEC, EPOLLERR, EPOLLEXCLUSIVE, EPOLLHUP, EPOLLIN, EPOLLMSG, EPOLLONESHOT, + EPOLLOUT, EPOLLPRI, EPOLLRDBAND, EPOLLRDHUP, EPOLLRDNORM, EPOLLWAKEUP, EPOLLWRBAND, + EPOLLWRNORM, }; #[cfg(any(target_os = "linux", target_os = "android", target_os = "redox"))] #[pyattr] @@ -503,13 +506,13 @@ mod 
decl { pub(super) mod epoll { use super::*; use crate::vm::{ + PyPayload, builtins::PyTypeRef, common::lock::{PyRwLock, PyRwLockReadGuard}, convert::{IntoPyException, ToPyObject}, function::OptionalArg, stdlib::io::Fildes, types::Constructor, - PyPayload, }; use rustix::event::epoll::{self, EventData, EventFlags}; use std::ops::Deref; @@ -643,7 +646,7 @@ mod decl { ..-1 => { return Err(vm.new_value_error(format!( "maxevents must be greater than 0, got {maxevents}" - ))) + ))); } -1 => libc::FD_SETSIZE - 1, _ => maxevents as usize, diff --git a/stdlib/src/sha1.rs b/stdlib/src/sha1.rs index 3820e7d96a..04845bb76b 100644 --- a/stdlib/src/sha1.rs +++ b/stdlib/src/sha1.rs @@ -2,7 +2,7 @@ pub(crate) use _sha1::make_module; #[pymodule] mod _sha1 { - use crate::hashlib::_hashlib::{local_sha1, HashArgs}; + use crate::hashlib::_hashlib::{HashArgs, local_sha1}; use crate::vm::{PyPayload, PyResult, VirtualMachine}; #[pyfunction] diff --git a/stdlib/src/sha256.rs b/stdlib/src/sha256.rs index cae0172666..5d031968ae 100644 --- a/stdlib/src/sha256.rs +++ b/stdlib/src/sha256.rs @@ -1,4 +1,4 @@ -use crate::vm::{builtins::PyModule, PyRef, VirtualMachine}; +use crate::vm::{PyRef, VirtualMachine, builtins::PyModule}; pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { let _ = vm.import("_hashlib", 0); @@ -7,7 +7,7 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { #[pymodule] mod _sha256 { - use crate::hashlib::_hashlib::{local_sha224, local_sha256, HashArgs}; + use crate::hashlib::_hashlib::{HashArgs, local_sha224, local_sha256}; use crate::vm::{PyPayload, PyResult, VirtualMachine}; #[pyfunction] diff --git a/stdlib/src/sha3.rs b/stdlib/src/sha3.rs index f0c1c5ef69..07b61d9aed 100644 --- a/stdlib/src/sha3.rs +++ b/stdlib/src/sha3.rs @@ -3,8 +3,8 @@ pub(crate) use _sha3::make_module; #[pymodule] mod _sha3 { use crate::hashlib::_hashlib::{ - local_sha3_224, local_sha3_256, local_sha3_384, local_sha3_512, local_shake_128, - local_shake_256, HashArgs, + HashArgs, 
local_sha3_224, local_sha3_256, local_sha3_384, local_sha3_512, local_shake_128, + local_shake_256, }; use crate::vm::{PyPayload, PyResult, VirtualMachine}; diff --git a/stdlib/src/sha512.rs b/stdlib/src/sha512.rs index 8c510fb730..baf63fdacf 100644 --- a/stdlib/src/sha512.rs +++ b/stdlib/src/sha512.rs @@ -1,4 +1,4 @@ -use crate::vm::{builtins::PyModule, PyRef, VirtualMachine}; +use crate::vm::{PyRef, VirtualMachine, builtins::PyModule}; pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { let _ = vm.import("_hashlib", 0); @@ -7,7 +7,7 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { #[pymodule] mod _sha512 { - use crate::hashlib::_hashlib::{local_sha384, local_sha512, HashArgs}; + use crate::hashlib::_hashlib::{HashArgs, local_sha384, local_sha512}; use crate::vm::{PyPayload, PyResult, VirtualMachine}; #[pyfunction] diff --git a/stdlib/src/socket.rs b/stdlib/src/socket.rs index bf2f5ecd30..a38b4f123c 100644 --- a/stdlib/src/socket.rs +++ b/stdlib/src/socket.rs @@ -1,6 +1,6 @@ -use crate::vm::{builtins::PyModule, PyRef, VirtualMachine}; +use crate::vm::{PyRef, VirtualMachine, builtins::PyModule}; #[cfg(feature = "ssl")] -pub(super) use _socket::{sock_select, timeout_error_msg, PySocket, SelectKind}; +pub(super) use _socket::{PySocket, SelectKind, sock_select, timeout_error_msg}; pub fn make_module(vm: &VirtualMachine) -> PyRef { #[cfg(windows)] @@ -12,13 +12,13 @@ pub fn make_module(vm: &VirtualMachine) -> PyRef { mod _socket { use crate::common::lock::{PyMappedRwLockReadGuard, PyRwLock, PyRwLockReadGuard}; use crate::vm::{ + AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{PyBaseExceptionRef, PyListRef, PyStrRef, PyTupleRef, PyTypeRef}, common::os::ErrorExt, convert::{IntoPyException, ToPyObject, TryFromBorrowedObject, TryFromObject}, function::{ArgBytesLike, ArgMemoryBuffer, Either, FsPath, OptionalArg, OptionalOption}, types::{Constructor, DefaultConstructor, Initializer, Representable}, utils::ToCString, - 
AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; use num_traits::ToPrimitive; @@ -40,32 +40,33 @@ mod _socket { INADDR_ANY, INADDR_BROADCAST, INADDR_LOOPBACK, INADDR_NONE, }; pub use winapi::um::winsock2::{ - getprotobyname, getservbyname, getservbyport, getsockopt, setsockopt, - SO_EXCLUSIVEADDRUSE, + SO_EXCLUSIVEADDRUSE, getprotobyname, getservbyname, getservbyport, getsockopt, + setsockopt, }; pub use winapi::um::ws2tcpip::{ EAI_AGAIN, EAI_BADFLAGS, EAI_FAIL, EAI_FAMILY, EAI_MEMORY, EAI_NODATA, EAI_NONAME, EAI_SERVICE, EAI_SOCKTYPE, }; pub use windows_sys::Win32::Networking::WinSock::{ - AF_DECnet, AF_APPLETALK, AF_IPX, AF_LINK, AI_ADDRCONFIG, AI_ALL, AI_CANONNAME, - AI_NUMERICSERV, AI_V4MAPPED, IPPORT_RESERVED, IPPROTO_AH, IPPROTO_DSTOPTS, IPPROTO_EGP, - IPPROTO_ESP, IPPROTO_FRAGMENT, IPPROTO_GGP, IPPROTO_HOPOPTS, IPPROTO_ICMP, - IPPROTO_ICMPV6, IPPROTO_IDP, IPPROTO_IGMP, IPPROTO_IP, IPPROTO_IP as IPPROTO_IPIP, - IPPROTO_IPV4, IPPROTO_IPV6, IPPROTO_ND, IPPROTO_NONE, IPPROTO_PIM, IPPROTO_PUP, - IPPROTO_RAW, IPPROTO_ROUTING, IPPROTO_TCP, IPPROTO_UDP, IPV6_CHECKSUM, IPV6_DONTFRAG, - IPV6_HOPLIMIT, IPV6_HOPOPTS, IPV6_JOIN_GROUP, IPV6_LEAVE_GROUP, IPV6_MULTICAST_HOPS, + AF_APPLETALK, AF_DECnet, AF_IPX, AF_LINK, AI_ADDRCONFIG, AI_ALL, AI_CANONNAME, + AI_NUMERICSERV, AI_V4MAPPED, IP_ADD_MEMBERSHIP, IP_DROP_MEMBERSHIP, IP_HDRINCL, + IP_MULTICAST_IF, IP_MULTICAST_LOOP, IP_MULTICAST_TTL, IP_OPTIONS, IP_RECVDSTADDR, + IP_TOS, IP_TTL, IPPORT_RESERVED, IPPROTO_AH, IPPROTO_DSTOPTS, IPPROTO_EGP, IPPROTO_ESP, + IPPROTO_FRAGMENT, IPPROTO_GGP, IPPROTO_HOPOPTS, IPPROTO_ICMP, IPPROTO_ICMPV6, + IPPROTO_IDP, IPPROTO_IGMP, IPPROTO_IP, IPPROTO_IP as IPPROTO_IPIP, IPPROTO_IPV4, + IPPROTO_IPV6, IPPROTO_ND, IPPROTO_NONE, IPPROTO_PIM, IPPROTO_PUP, IPPROTO_RAW, + IPPROTO_ROUTING, IPPROTO_TCP, IPPROTO_UDP, IPV6_CHECKSUM, IPV6_DONTFRAG, IPV6_HOPLIMIT, + IPV6_HOPOPTS, IPV6_JOIN_GROUP, IPV6_LEAVE_GROUP, IPV6_MULTICAST_HOPS, 
IPV6_MULTICAST_IF, IPV6_MULTICAST_LOOP, IPV6_PKTINFO, IPV6_RECVRTHDR, IPV6_RECVTCLASS, - IPV6_RTHDR, IPV6_TCLASS, IPV6_UNICAST_HOPS, IPV6_V6ONLY, IP_ADD_MEMBERSHIP, - IP_DROP_MEMBERSHIP, IP_HDRINCL, IP_MULTICAST_IF, IP_MULTICAST_LOOP, IP_MULTICAST_TTL, - IP_OPTIONS, IP_RECVDSTADDR, IP_TOS, IP_TTL, MSG_BCAST, MSG_CTRUNC, MSG_DONTROUTE, - MSG_MCAST, MSG_OOB, MSG_PEEK, MSG_TRUNC, MSG_WAITALL, NI_DGRAM, NI_MAXHOST, NI_MAXSERV, - NI_NAMEREQD, NI_NOFQDN, NI_NUMERICHOST, NI_NUMERICSERV, RCVALL_IPLEVEL, RCVALL_OFF, - RCVALL_ON, RCVALL_SOCKETLEVELONLY, SD_BOTH as SHUT_RDWR, SD_RECEIVE as SHUT_RD, - SD_SEND as SHUT_WR, SIO_KEEPALIVE_VALS, SIO_LOOPBACK_FAST_PATH, SIO_RCVALL, SOCK_DGRAM, - SOCK_RAW, SOCK_RDM, SOCK_SEQPACKET, SOCK_STREAM, SOL_SOCKET, SOMAXCONN, SO_BROADCAST, - SO_ERROR, SO_LINGER, SO_OOBINLINE, SO_REUSEADDR, SO_TYPE, SO_USELOOPBACK, TCP_NODELAY, - WSAEBADF, WSAECONNRESET, WSAENOTSOCK, WSAEWOULDBLOCK, + IPV6_RTHDR, IPV6_TCLASS, IPV6_UNICAST_HOPS, IPV6_V6ONLY, MSG_BCAST, MSG_CTRUNC, + MSG_DONTROUTE, MSG_MCAST, MSG_OOB, MSG_PEEK, MSG_TRUNC, MSG_WAITALL, NI_DGRAM, + NI_MAXHOST, NI_MAXSERV, NI_NAMEREQD, NI_NOFQDN, NI_NUMERICHOST, NI_NUMERICSERV, + RCVALL_IPLEVEL, RCVALL_OFF, RCVALL_ON, RCVALL_SOCKETLEVELONLY, SD_BOTH as SHUT_RDWR, + SD_RECEIVE as SHUT_RD, SD_SEND as SHUT_WR, SIO_KEEPALIVE_VALS, SIO_LOOPBACK_FAST_PATH, + SIO_RCVALL, SO_BROADCAST, SO_ERROR, SO_LINGER, SO_OOBINLINE, SO_REUSEADDR, SO_TYPE, + SO_USELOOPBACK, SOCK_DGRAM, SOCK_RAW, SOCK_RDM, SOCK_SEQPACKET, SOCK_STREAM, + SOL_SOCKET, SOMAXCONN, TCP_NODELAY, WSAEBADF, WSAECONNRESET, WSAENOTSOCK, + WSAEWOULDBLOCK, }; pub const IF_NAMESIZE: usize = windows_sys::Win32::NetworkManagement::Ndis::IF_MAX_STRING_SIZE as _; @@ -86,14 +87,14 @@ mod _socket { IPPROTO_ICMPV6, IPPROTO_IP, IPPROTO_IPV6, IPPROTO_TCP, IPPROTO_TCP as SOL_TCP, IPPROTO_UDP, MSG_CTRUNC, MSG_DONTROUTE, MSG_OOB, MSG_PEEK, MSG_TRUNC, MSG_WAITALL, NI_DGRAM, NI_MAXHOST, NI_NAMEREQD, NI_NOFQDN, NI_NUMERICHOST, NI_NUMERICSERV, SHUT_RD, SHUT_RDWR, 
SHUT_WR, - SOCK_DGRAM, SOCK_STREAM, SOL_SOCKET, SO_BROADCAST, SO_ERROR, SO_LINGER, SO_OOBINLINE, - SO_REUSEADDR, SO_TYPE, TCP_NODELAY, + SO_BROADCAST, SO_ERROR, SO_LINGER, SO_OOBINLINE, SO_REUSEADDR, SO_TYPE, SOCK_DGRAM, + SOCK_STREAM, SOL_SOCKET, TCP_NODELAY, }; #[cfg(not(target_os = "redox"))] #[pyattr] use c::{ - AF_DECnet, AF_APPLETALK, AF_IPX, IPPROTO_AH, IPPROTO_DSTOPTS, IPPROTO_EGP, IPPROTO_ESP, + AF_APPLETALK, AF_DECnet, AF_IPX, IPPROTO_AH, IPPROTO_DSTOPTS, IPPROTO_EGP, IPPROTO_ESP, IPPROTO_FRAGMENT, IPPROTO_HOPOPTS, IPPROTO_IDP, IPPROTO_IGMP, IPPROTO_IPIP, IPPROTO_NONE, IPPROTO_PIM, IPPROTO_PUP, IPPROTO_RAW, IPPROTO_ROUTING, }; @@ -126,8 +127,9 @@ mod _socket { J1939_IDLE_ADDR, J1939_MAX_UNICAST_ADDR, J1939_NLA_BYTES_ACKED, J1939_NLA_PAD, J1939_NO_ADDR, J1939_NO_NAME, J1939_NO_PGN, J1939_PGN_ADDRESS_CLAIMED, J1939_PGN_ADDRESS_COMMANDED, J1939_PGN_MAX, J1939_PGN_PDU1_MAX, J1939_PGN_REQUEST, - SCM_J1939_DEST_ADDR, SCM_J1939_DEST_NAME, SCM_J1939_ERRQUEUE, SCM_J1939_PRIO, SOL_CAN_BASE, - SOL_CAN_RAW, SO_J1939_ERRQUEUE, SO_J1939_FILTER, SO_J1939_PROMISC, SO_J1939_SEND_PRIO, + SCM_J1939_DEST_ADDR, SCM_J1939_DEST_NAME, SCM_J1939_ERRQUEUE, SCM_J1939_PRIO, + SO_J1939_ERRQUEUE, SO_J1939_FILTER, SO_J1939_PROMISC, SO_J1939_SEND_PRIO, SOL_CAN_BASE, + SOL_CAN_RAW, }; #[cfg(all(target_os = "linux", target_env = "gnu"))] @@ -168,11 +170,11 @@ mod _socket { #[pyattr] use c::{ ALG_OP_DECRYPT, ALG_OP_ENCRYPT, ALG_SET_AEAD_ASSOCLEN, ALG_SET_AEAD_AUTHSIZE, ALG_SET_IV, - ALG_SET_KEY, ALG_SET_OP, IPV6_DSTOPTS, IPV6_NEXTHOP, IPV6_PATHMTU, IPV6_RECVDSTOPTS, - IPV6_RECVHOPLIMIT, IPV6_RECVHOPOPTS, IPV6_RECVPATHMTU, IPV6_RTHDRDSTOPTS, - IP_DEFAULT_MULTICAST_LOOP, IP_RECVOPTS, IP_RETOPTS, NETLINK_CRYPTO, NETLINK_DNRTMSG, - NETLINK_FIREWALL, NETLINK_IP6_FW, NETLINK_NFLOG, NETLINK_ROUTE, NETLINK_USERSOCK, - NETLINK_XFRM, SOL_ALG, SO_PASSSEC, SO_PEERSEC, + ALG_SET_KEY, ALG_SET_OP, IP_DEFAULT_MULTICAST_LOOP, IP_RECVOPTS, IP_RETOPTS, IPV6_DSTOPTS, + IPV6_NEXTHOP, IPV6_PATHMTU, 
IPV6_RECVDSTOPTS, IPV6_RECVHOPLIMIT, IPV6_RECVHOPOPTS, + IPV6_RECVPATHMTU, IPV6_RTHDRDSTOPTS, NETLINK_CRYPTO, NETLINK_DNRTMSG, NETLINK_FIREWALL, + NETLINK_IP6_FW, NETLINK_NFLOG, NETLINK_ROUTE, NETLINK_USERSOCK, NETLINK_XFRM, SO_PASSSEC, + SO_PEERSEC, SOL_ALG, }; #[cfg(any(target_os = "android", target_vendor = "apple"))] @@ -190,9 +192,9 @@ mod _socket { #[cfg(any(unix, target_os = "android", windows))] #[pyattr] use c::{ - INADDR_BROADCAST, IPV6_MULTICAST_HOPS, IPV6_MULTICAST_IF, IPV6_MULTICAST_LOOP, - IPV6_UNICAST_HOPS, IPV6_V6ONLY, IP_ADD_MEMBERSHIP, IP_DROP_MEMBERSHIP, IP_MULTICAST_IF, - IP_MULTICAST_LOOP, IP_MULTICAST_TTL, IP_TTL, + INADDR_BROADCAST, IP_ADD_MEMBERSHIP, IP_DROP_MEMBERSHIP, IP_MULTICAST_IF, + IP_MULTICAST_LOOP, IP_MULTICAST_TTL, IP_TTL, IPV6_MULTICAST_HOPS, IPV6_MULTICAST_IF, + IPV6_MULTICAST_LOOP, IPV6_UNICAST_HOPS, IPV6_V6ONLY, }; #[cfg(any(unix, target_os = "android", windows))] @@ -213,8 +215,8 @@ mod _socket { AF_ALG, AF_ASH, AF_ATMPVC, AF_ATMSVC, AF_AX25, AF_BRIDGE, AF_CAN, AF_ECONET, AF_IRDA, AF_LLC, AF_NETBEUI, AF_NETLINK, AF_NETROM, AF_PACKET, AF_PPPOX, AF_RDS, AF_SECURITY, AF_TIPC, AF_VSOCK, AF_WANPIPE, AF_X25, IP_TRANSPARENT, MSG_CONFIRM, MSG_ERRQUEUE, - MSG_FASTOPEN, MSG_MORE, PF_CAN, PF_PACKET, PF_RDS, SCM_CREDENTIALS, SOL_IP, SOL_TIPC, - SOL_UDP, SO_BINDTODEVICE, SO_MARK, TCP_CORK, TCP_DEFER_ACCEPT, TCP_LINGER2, TCP_QUICKACK, + MSG_FASTOPEN, MSG_MORE, PF_CAN, PF_PACKET, PF_RDS, SCM_CREDENTIALS, SO_BINDTODEVICE, + SO_MARK, SOL_IP, SOL_TIPC, SOL_UDP, TCP_CORK, TCP_DEFER_ACCEPT, TCP_LINGER2, TCP_QUICKACK, TCP_SYNCNT, TCP_WINDOW_CLAMP, }; @@ -271,7 +273,7 @@ mod _socket { #[cfg(any(target_os = "android", target_os = "linux", windows))] #[pyattr] - use c::{IPV6_HOPOPTS, IPV6_RECVRTHDR, IPV6_RTHDR, IP_OPTIONS}; + use c::{IP_OPTIONS, IPV6_HOPOPTS, IPV6_RECVRTHDR, IPV6_RTHDR}; #[cfg(any( target_os = "dragonfly", @@ -525,7 +527,7 @@ mod _socket { ))] #[pyattr] use c::{ - AF_LINK, IPPROTO_GGP, IPV6_JOIN_GROUP, IPV6_LEAVE_GROUP, 
IP_RECVDSTADDR, SO_USELOOPBACK, + AF_LINK, IP_RECVDSTADDR, IPPROTO_GGP, IPV6_JOIN_GROUP, IPV6_LEAVE_GROUP, SO_USELOOPBACK, }; #[cfg(any( @@ -633,7 +635,7 @@ mod _socket { #[pyattr] use c::{ EAI_AGAIN, EAI_BADFLAGS, EAI_FAIL, EAI_FAMILY, EAI_MEMORY, EAI_NONAME, EAI_SERVICE, - EAI_SOCKTYPE, IPV6_RECVTCLASS, IPV6_TCLASS, IP_HDRINCL, IP_TOS, SOMAXCONN, + EAI_SOCKTYPE, IP_HDRINCL, IP_TOS, IPV6_RECVTCLASS, IPV6_TCLASS, SOMAXCONN, }; #[cfg(not(any( @@ -1473,11 +1475,7 @@ mod _socket { #[pymethod] fn gettimeout(&self) -> Option { let timeout = self.timeout.load(); - if timeout >= 0.0 { - Some(timeout) - } else { - None - } + if timeout >= 0.0 { Some(timeout) } else { None } } #[pymethod] @@ -1601,7 +1599,7 @@ mod _socket { _ => { return Err(vm .new_value_error("`how` must be SHUT_RD, SHUT_WR, or SHUT_RDWR".to_owned()) - .into()) + .into()); } }; Ok(self.sock()?.shutdown(how)?) @@ -1788,7 +1786,7 @@ mod _socket { } unsafe fn slice_as_uninit(v: &mut [T]) -> &mut [MaybeUninit] { - &mut *(v as *mut [T] as *mut [MaybeUninit]) + unsafe { &mut *(v as *mut [T] as *mut [MaybeUninit]) } } enum IoOrPyException { @@ -1924,7 +1922,7 @@ mod _socket { let host = opts.host.as_ref().map(|s| s.as_str()); let port = opts.port.as_ref().map(|p| -> std::borrow::Cow { match p { - Either::A(ref s) => s.as_str().into(), + Either::A(s) => s.as_str().into(), Either::B(i) => i.to_string().into(), } }); @@ -2054,7 +2052,7 @@ mod _socket { _ => { return Err(vm .new_type_error("illegal sockaddr argument".to_owned()) - .into()) + .into()); } } let (addr, flowinfo, scopeid) = Address::from_tuple_ipv6(&address, vm)?; @@ -2259,7 +2257,7 @@ mod _socket { _ => { return Err(vm .new_os_error("address family mismatched".to_owned()) - .into()) + .into()); } } return Ok(SocketAddr::V4(net::SocketAddrV4::new( @@ -2312,12 +2310,12 @@ mod _socket { #[cfg(unix)] { use std::os::unix::io::FromRawFd; - Socket::from_raw_fd(fileno) + unsafe { Socket::from_raw_fd(fileno) } } #[cfg(windows)] { use 
std::os::windows::io::FromRawSocket; - Socket::from_raw_socket(fileno) + unsafe { Socket::from_raw_socket(fileno) } } } pub(super) fn sock_fileno(sock: &Socket) -> RawSocket { @@ -2433,11 +2431,7 @@ mod _socket { #[pyfunction] fn getdefaulttimeout() -> Option { let timeout = DEFAULT_TIMEOUT.load(); - if timeout >= 0.0 { - Some(timeout) - } else { - None - } + if timeout >= 0.0 { Some(timeout) } else { None } } #[pyfunction] diff --git a/stdlib/src/sqlite.rs b/stdlib/src/sqlite.rs index 6cb3deae7d..5487511e20 100644 --- a/stdlib/src/sqlite.rs +++ b/stdlib/src/sqlite.rs @@ -8,7 +8,7 @@ // spell-checker:ignore cantlock commithook foreignkey notnull primarykey gettemppath autoindex convpath // spell-checker:ignore dbmoved vnode nbytes -use rustpython_vm::{builtins::PyModule, AsObject, PyRef, VirtualMachine}; +use rustpython_vm::{AsObject, PyRef, VirtualMachine, builtins::PyModule}; // pub(crate) use _sqlite::make_module; pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { @@ -21,29 +21,29 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { #[pymodule] mod _sqlite { use libsqlite3_sys::{ - sqlite3, sqlite3_aggregate_context, sqlite3_backup_finish, sqlite3_backup_init, - sqlite3_backup_pagecount, sqlite3_backup_remaining, sqlite3_backup_step, sqlite3_bind_blob, - sqlite3_bind_double, sqlite3_bind_int64, sqlite3_bind_null, sqlite3_bind_parameter_count, - sqlite3_bind_parameter_name, sqlite3_bind_text, sqlite3_blob, sqlite3_blob_bytes, - sqlite3_blob_close, sqlite3_blob_open, sqlite3_blob_read, sqlite3_blob_write, - sqlite3_busy_timeout, sqlite3_changes, sqlite3_close_v2, sqlite3_column_blob, - sqlite3_column_bytes, sqlite3_column_count, sqlite3_column_decltype, sqlite3_column_double, - sqlite3_column_int64, sqlite3_column_name, sqlite3_column_text, sqlite3_column_type, - sqlite3_complete, sqlite3_context, sqlite3_context_db_handle, sqlite3_create_collation_v2, - sqlite3_create_function_v2, sqlite3_create_window_function, sqlite3_data_count, - 
sqlite3_db_handle, sqlite3_errcode, sqlite3_errmsg, sqlite3_exec, sqlite3_expanded_sql, - sqlite3_extended_errcode, sqlite3_finalize, sqlite3_get_autocommit, sqlite3_interrupt, - sqlite3_last_insert_rowid, sqlite3_libversion, sqlite3_limit, sqlite3_open_v2, - sqlite3_prepare_v2, sqlite3_progress_handler, sqlite3_reset, sqlite3_result_blob, - sqlite3_result_double, sqlite3_result_error, sqlite3_result_error_nomem, - sqlite3_result_error_toobig, sqlite3_result_int64, sqlite3_result_null, - sqlite3_result_text, sqlite3_set_authorizer, sqlite3_sleep, sqlite3_step, sqlite3_stmt, - sqlite3_stmt_busy, sqlite3_stmt_readonly, sqlite3_threadsafe, sqlite3_total_changes, - sqlite3_trace_v2, sqlite3_user_data, sqlite3_value, sqlite3_value_blob, - sqlite3_value_bytes, sqlite3_value_double, sqlite3_value_int64, sqlite3_value_text, - sqlite3_value_type, SQLITE_BLOB, SQLITE_DETERMINISTIC, SQLITE_FLOAT, SQLITE_INTEGER, - SQLITE_NULL, SQLITE_OPEN_CREATE, SQLITE_OPEN_READWRITE, SQLITE_OPEN_URI, SQLITE_TEXT, - SQLITE_TRACE_STMT, SQLITE_TRANSIENT, SQLITE_UTF8, + SQLITE_BLOB, SQLITE_DETERMINISTIC, SQLITE_FLOAT, SQLITE_INTEGER, SQLITE_NULL, + SQLITE_OPEN_CREATE, SQLITE_OPEN_READWRITE, SQLITE_OPEN_URI, SQLITE_TEXT, SQLITE_TRACE_STMT, + SQLITE_TRANSIENT, SQLITE_UTF8, sqlite3, sqlite3_aggregate_context, sqlite3_backup_finish, + sqlite3_backup_init, sqlite3_backup_pagecount, sqlite3_backup_remaining, + sqlite3_backup_step, sqlite3_bind_blob, sqlite3_bind_double, sqlite3_bind_int64, + sqlite3_bind_null, sqlite3_bind_parameter_count, sqlite3_bind_parameter_name, + sqlite3_bind_text, sqlite3_blob, sqlite3_blob_bytes, sqlite3_blob_close, sqlite3_blob_open, + sqlite3_blob_read, sqlite3_blob_write, sqlite3_busy_timeout, sqlite3_changes, + sqlite3_close_v2, sqlite3_column_blob, sqlite3_column_bytes, sqlite3_column_count, + sqlite3_column_decltype, sqlite3_column_double, sqlite3_column_int64, sqlite3_column_name, + sqlite3_column_text, sqlite3_column_type, sqlite3_complete, sqlite3_context, + 
sqlite3_context_db_handle, sqlite3_create_collation_v2, sqlite3_create_function_v2, + sqlite3_create_window_function, sqlite3_data_count, sqlite3_db_handle, sqlite3_errcode, + sqlite3_errmsg, sqlite3_exec, sqlite3_expanded_sql, sqlite3_extended_errcode, + sqlite3_finalize, sqlite3_get_autocommit, sqlite3_interrupt, sqlite3_last_insert_rowid, + sqlite3_libversion, sqlite3_limit, sqlite3_open_v2, sqlite3_prepare_v2, + sqlite3_progress_handler, sqlite3_reset, sqlite3_result_blob, sqlite3_result_double, + sqlite3_result_error, sqlite3_result_error_nomem, sqlite3_result_error_toobig, + sqlite3_result_int64, sqlite3_result_null, sqlite3_result_text, sqlite3_set_authorizer, + sqlite3_sleep, sqlite3_step, sqlite3_stmt, sqlite3_stmt_busy, sqlite3_stmt_readonly, + sqlite3_threadsafe, sqlite3_total_changes, sqlite3_trace_v2, sqlite3_user_data, + sqlite3_value, sqlite3_value_blob, sqlite3_value_bytes, sqlite3_value_double, + sqlite3_value_int64, sqlite3_value_text, sqlite3_value_type, }; use malachite_bigint::Sign; use rustpython_common::{ @@ -53,13 +53,16 @@ mod _sqlite { static_cell, }; use rustpython_vm::{ - atomic_func, + __exports::paste, + AsObject, Py, PyAtomicRef, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, + TryFromBorrowedObject, VirtualMachine, atomic_func, builtins::{ PyBaseException, PyBaseExceptionRef, PyByteArray, PyBytes, PyDict, PyDictRef, PyFloat, PyInt, PyIntRef, PySlice, PyStr, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, }, convert::IntoObject, function::{ArgCallable, ArgIterable, FsPath, FuncArgs, OptionalArg, PyComparisonValue}, + object::{Traverse, TraverseFn}, protocol::{PyBuffer, PyIterReturn, PyMappingMethods, PySequence, PySequenceMethods}, sliceable::{SaturatedSliceIter, SliceableSequenceOp}, types::{ @@ -67,16 +70,12 @@ mod _sqlite { PyComparisonOp, SelfIter, }, utils::ToCString, - AsObject, Py, PyAtomicRef, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, - TryFromBorrowedObject, VirtualMachine, - __exports::paste, - 
object::{Traverse, TraverseFn}, }; use std::{ - ffi::{c_int, c_longlong, c_uint, c_void, CStr}, + ffi::{CStr, c_int, c_longlong, c_uint, c_void}, fmt::Debug, ops::Deref, - ptr::{null, null_mut}, + ptr::{NonNull, null, null_mut}, thread::ThreadId, }; @@ -381,7 +380,7 @@ mod _sqlite { } struct CallbackData { - obj: *const PyObject, + obj: NonNull, vm: *const VirtualMachine, } @@ -394,11 +393,11 @@ mod _sqlite { } fn retrieve(&self) -> (&PyObject, &VirtualMachine) { - unsafe { (&*self.obj, &*self.vm) } + unsafe { (self.obj.as_ref(), &*self.vm) } } unsafe extern "C" fn destructor(data: *mut c_void) { - drop(Box::from_raw(data.cast::())); + drop(unsafe { Box::from_raw(data.cast::()) }); } unsafe extern "C" fn func_callback( @@ -407,8 +406,8 @@ mod _sqlite { argv: *mut *mut sqlite3_value, ) { let context = SqliteContext::from(context); - let (func, vm) = (*context.user_data::()).retrieve(); - let args = std::slice::from_raw_parts(argv, argc as usize); + let (func, vm) = unsafe { (*context.user_data::()).retrieve() }; + let args = unsafe { std::slice::from_raw_parts(argv, argc as usize) }; let f = || -> PyResult<()> { let db = context.db_handle(); @@ -434,31 +433,31 @@ mod _sqlite { argv: *mut *mut sqlite3_value, ) { let context = SqliteContext::from(context); - let (cls, vm) = (*context.user_data::()).retrieve(); - let args = std::slice::from_raw_parts(argv, argc as usize); + let (cls, vm) = unsafe { (*context.user_data::()).retrieve() }; + let args = unsafe { std::slice::from_raw_parts(argv, argc as usize) }; let instance = context.aggregate_context::<*const PyObject>(); - if (*instance).is_null() { + if unsafe { (*instance).is_null() } { match cls.call((), vm) { - Ok(obj) => *instance = obj.into_raw(), + Ok(obj) => unsafe { *instance = obj.into_raw().as_ptr() }, Err(exc) => { return context.result_exception( vm, exc, "user-defined aggregate's '__init__' method raised error\0", - ) + ); } } } - let instance = &**instance; + let instance = unsafe { &**instance }; 
Self::call_method_with_args(context, instance, "step", args, vm); } unsafe extern "C" fn finalize_callback(context: *mut sqlite3_context) { let context = SqliteContext::from(context); - let (_, vm) = (*context.user_data::()).retrieve(); + let (_, vm) = unsafe { (*context.user_data::()).retrieve() }; let instance = context.aggregate_context::<*const PyObject>(); - let Some(instance) = (*instance).as_ref() else { + let Some(instance) = (unsafe { (*instance).as_ref() }) else { return; }; @@ -472,7 +471,7 @@ mod _sqlite { b_len: c_int, b_ptr: *const c_void, ) -> c_int { - let (callable, vm) = (*data.cast::()).retrieve(); + let (callable, vm) = unsafe { (*data.cast::()).retrieve() }; let f = || -> PyResult { let text1 = ptr_to_string(a_ptr.cast(), a_len, null_mut(), vm)?; @@ -499,9 +498,9 @@ mod _sqlite { unsafe extern "C" fn value_callback(context: *mut sqlite3_context) { let context = SqliteContext::from(context); - let (_, vm) = (*context.user_data::()).retrieve(); + let (_, vm) = unsafe { (*context.user_data::()).retrieve() }; let instance = context.aggregate_context::<*const PyObject>(); - let instance = &**instance; + let instance = unsafe { &**instance }; Self::callback_result_from_method(context, instance, "value", vm); } @@ -512,10 +511,10 @@ mod _sqlite { argv: *mut *mut sqlite3_value, ) { let context = SqliteContext::from(context); - let (_, vm) = (*context.user_data::()).retrieve(); - let args = std::slice::from_raw_parts(argv, argc as usize); + let (_, vm) = unsafe { (*context.user_data::()).retrieve() }; + let args = unsafe { std::slice::from_raw_parts(argv, argc as usize) }; let instance = context.aggregate_context::<*const PyObject>(); - let instance = &**instance; + let instance = unsafe { &**instance }; Self::call_method_with_args(context, instance, "inverse", args, vm); } @@ -528,7 +527,7 @@ mod _sqlite { db_name: *const libc::c_char, access: *const libc::c_char, ) -> c_int { - let (callable, vm) = (*data.cast::()).retrieve(); + let (callable, vm) = 
unsafe { (*data.cast::()).retrieve() }; let f = || -> PyResult { let arg1 = ptr_to_str(arg1, vm)?; let arg2 = ptr_to_str(arg2, vm)?; @@ -551,8 +550,8 @@ mod _sqlite { stmt: *mut c_void, sql: *mut c_void, ) -> c_int { - let (callable, vm) = (*data.cast::()).retrieve(); - let expanded = sqlite3_expanded_sql(stmt.cast()); + let (callable, vm) = unsafe { (*data.cast::()).retrieve() }; + let expanded = unsafe { sqlite3_expanded_sql(stmt.cast()) }; let f = || -> PyResult<()> { let stmt = ptr_to_str(expanded, vm).or_else(|_| ptr_to_str(sql.cast(), vm))?; callable.call((stmt,), vm)?; @@ -563,7 +562,7 @@ mod _sqlite { } unsafe extern "C" fn progress_callback(data: *mut c_void) -> c_int { - let (callable, vm) = (*data.cast::()).retrieve(); + let (callable, vm) = unsafe { (*data.cast::()).retrieve() }; if let Ok(val) = callable.call((), vm) { if let Ok(val) = val.is_true(vm) { return val as c_int; @@ -2082,7 +2081,7 @@ mod _sqlite { _ => { return Err(vm.new_value_error( "'origin' should be os.SEEK_SET, os.SEEK_CUR, or os.SEEK_END".to_owned(), - )) + )); } } diff --git a/stdlib/src/ssl.rs b/stdlib/src/ssl.rs index 03222ff4f5..2b8ffc7d8c 100644 --- a/stdlib/src/ssl.rs +++ b/stdlib/src/ssl.rs @@ -1,19 +1,32 @@ -use crate::vm::{builtins::PyModule, PyRef, VirtualMachine}; +use crate::vm::{PyRef, VirtualMachine, builtins::PyModule}; +use openssl_probe::ProbeResult; pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { - // if openssl is vendored, it doesn't know the locations of system certificates - #[cfg(feature = "ssl-vendor")] - if let None | Some("0") = option_env!("OPENSSL_NO_VENDOR") { - openssl_probe::init_ssl_cert_env_vars(); - } - openssl::init(); + // if openssl is vendored, it doesn't know the locations + // of system certificates - cache the probe result now. 
+ #[cfg(openssl_vendored)] + LazyLock::force(&PROBE); _ssl::make_module(vm) } +// define our own copy of ProbeResult so we can handle the vendor case +// easily, without having to have a bunch of cfgs +cfg_if::cfg_if! { + if #[cfg(openssl_vendored)] { + use std::sync::LazyLock; + static PROBE: LazyLock = LazyLock::new(openssl_probe::probe); + fn probe() -> &'static ProbeResult { &PROBE } + } else { + fn probe() -> &'static ProbeResult { + &ProbeResult { cert_file: None, cert_dir: None } + } + } +} + #[allow(non_upper_case_globals)] #[pymodule(with(ossl101, windows))] mod _ssl { - use super::bio; + use super::{bio, probe}; use crate::{ common::{ ascii, @@ -23,6 +36,7 @@ mod _ssl { }, socket::{self, PySocket}, vm::{ + PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{PyBaseExceptionRef, PyStrRef, PyType, PyTypeRef, PyWeak}, convert::{ToPyException, ToPyObject}, exceptions, @@ -32,7 +46,6 @@ mod _ssl { }, types::Constructor, utils::ToCString, - PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }, }; use crossbeam_utils::atomic::AtomicCell; @@ -42,22 +55,21 @@ mod _ssl { error::ErrorStack, nid::Nid, ssl::{self, SslContextBuilder, SslOptions, SslVerifyMode}, - x509::{self, X509Ref, X509}, + x509::{self, X509, X509Ref}, }; use openssl_sys as sys; + use rustpython_vm::ospath::OsPath; use std::{ ffi::CStr, fmt, io::{Read, Write}, + path::Path, time::Instant, }; // Constants #[pyattr] use sys::{ - SSL_OP_NO_SSLv2 as OP_NO_SSLv2, - SSL_OP_NO_SSLv3 as OP_NO_SSLv3, - SSL_OP_NO_TLSv1 as OP_NO_TLSv1, // TODO: so many more of these SSL_AD_DECODE_ERROR as ALERT_DESCRIPTION_DECODE_ERROR, SSL_AD_ILLEGAL_PARAMETER as ALERT_DESCRIPTION_ILLEGAL_PARAMETER, @@ -77,7 +89,10 @@ mod _ssl { // X509_V_FLAG_X509_STRICT as VERIFY_X509_STRICT, SSL_ERROR_ZERO_RETURN, SSL_OP_CIPHER_SERVER_PREFERENCE as OP_CIPHER_SERVER_PREFERENCE, + SSL_OP_NO_SSLv2 as OP_NO_SSLv2, + SSL_OP_NO_SSLv3 as OP_NO_SSLv3, SSL_OP_NO_TICKET as OP_NO_TICKET, + SSL_OP_NO_TLSv1 as OP_NO_TLSv1, 
SSL_OP_SINGLE_DH_USE as OP_SINGLE_DH_USE, }; @@ -283,7 +298,7 @@ mod _ssl { if ptr.is_null() { None } else { - Some(Asn1Object::from_ptr(ptr)) + Some(unsafe { Asn1Object::from_ptr(ptr) }) } } @@ -354,21 +369,43 @@ mod _ssl { .ok_or_else(|| vm.new_value_error(format!("unknown NID {nid}"))) } - #[pyfunction] - fn get_default_verify_paths() -> (String, String, String, String) { - macro_rules! convert { - ($f:ident) => { - CStr::from_ptr(sys::$f()).to_string_lossy().into_owned() - }; - } - unsafe { - ( - convert!(X509_get_default_cert_file_env), - convert!(X509_get_default_cert_file), - convert!(X509_get_default_cert_dir_env), - convert!(X509_get_default_cert_dir), - ) + fn get_cert_file_dir() -> (&'static Path, &'static Path) { + let probe = probe(); + // on windows, these should be utf8 strings + fn path_from_bytes(c: &CStr) -> &Path { + #[cfg(unix)] + { + use std::os::unix::ffi::OsStrExt; + std::ffi::OsStr::from_bytes(c.to_bytes()).as_ref() + } + #[cfg(windows)] + { + c.to_str().unwrap().as_ref() + } } + let cert_file = probe.cert_file.as_deref().unwrap_or_else(|| { + path_from_bytes(unsafe { CStr::from_ptr(sys::X509_get_default_cert_file()) }) + }); + let cert_dir = probe.cert_dir.as_deref().unwrap_or_else(|| { + path_from_bytes(unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir()) }) + }); + (cert_file, cert_dir) + } + + #[pyfunction] + fn get_default_verify_paths( + vm: &VirtualMachine, + ) -> PyResult<(&'static str, PyObjectRef, &'static str, PyObjectRef)> { + let cert_file_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_file_env()) } + .to_str() + .unwrap(); + let cert_dir_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir_env()) } + .to_str() + .unwrap(); + let (cert_file, cert_dir) = get_cert_file_dir(); + let cert_file = OsPath::new_str(cert_file).filename(vm)?; + let cert_dir = OsPath::new_str(cert_dir).filename(vm)?; + Ok((cert_file_env, cert_file, cert_dir_env, cert_dir)) } #[pyfunction(name = "RAND_status")] @@ -564,7 +601,7 @@ mod 
_ssl { return Err(vm.new_value_error( "Cannot set verify_mode to CERT_NONE when check_hostname is enabled." .to_owned(), - )) + )); } CertRequirements::None => SslVerifyMode::NONE, CertRequirements::Optional => SslVerifyMode::PEER, @@ -590,9 +627,18 @@ mod _ssl { #[pymethod] fn set_default_verify_paths(&self, vm: &VirtualMachine) -> PyResult<()> { - self.builder() - .set_default_verify_paths() - .map_err(|e| convert_openssl_error(vm, e)) + cfg_if::cfg_if! { + if #[cfg(openssl_vendored)] { + let (cert_file, cert_dir) = get_cert_file_dir(); + self.builder() + .load_verify_locations(Some(cert_file), Some(cert_dir)) + .map_err(|e| convert_openssl_error(vm, e)) + } else { + self.builder() + .set_default_verify_paths() + .map_err(|e| convert_openssl_error(vm, e)) + } + } } #[pymethod] @@ -612,7 +658,11 @@ mod _ssl { Ok(pbuf.to_vec()) })?; ctx.set_alpn_select_callback(move |_, client| { - ssl::select_next_proto(&server, client).ok_or(ssl::AlpnError::NOACK) + let proto = + ssl::select_next_proto(&server, client).ok_or(ssl::AlpnError::NOACK)?; + let pos = memchr::memmem::find(client, proto) + .expect("selected alpn proto should be present in client protos"); + Ok(&client[pos..proto.len()]) }); Ok(()) } @@ -635,6 +685,12 @@ mod _ssl { vm.new_type_error("cafile, capath and cadata cannot be all omitted".to_owned()) ); } + if let Some(cafile) = &args.cafile { + cafile.ensure_no_nul(vm)? + } + if let Some(capath) = &args.capath { + capath.ensure_no_nul(vm)? 
+ } #[cold] fn invalid_cadata(vm: &VirtualMachine) -> PyBaseExceptionRef { @@ -643,6 +699,8 @@ mod _ssl { ) } + let mut ctx = self.builder(); + // validate cadata type and load cadata if let Some(cadata) = args.cadata { let certs = match cadata { @@ -655,7 +713,6 @@ mod _ssl { Either::B(b) => b.with_ref(x509_stack_from_der), }; let certs = certs.map_err(|e| convert_openssl_error(vm, e))?; - let mut ctx = self.builder(); let store = ctx.cert_store_mut(); for cert in certs { store @@ -665,29 +722,11 @@ mod _ssl { } if args.cafile.is_some() || args.capath.is_some() { - let cafile = args.cafile.map(|s| s.to_cstring(vm)).transpose()?; - let capath = args.capath.map(|s| s.to_cstring(vm)).transpose()?; - let ret = unsafe { - let ctx = self.ctx.write(); - sys::SSL_CTX_load_verify_locations( - ctx.as_ptr(), - cafile - .as_ref() - .map_or_else(std::ptr::null, |cs| cs.as_ptr()), - capath - .as_ref() - .map_or_else(std::ptr::null, |cs| cs.as_ptr()), - ) - }; - if ret != 1 { - let errno = crate::common::os::last_posix_errno(); - let err = if errno != 0 { - crate::vm::stdlib::os::errno_err(vm) - } else { - convert_openssl_error(vm, ErrorStack::get()) - }; - return Err(err); - } + ctx.load_verify_locations( + args.cafile.as_ref().map(|s| s.as_str().as_ref()), + args.capath.as_ref().map(|s| s.as_str().as_ref()), + ) + .map_err(|e| convert_openssl_error(vm, e))?; } Ok(()) @@ -967,11 +1006,7 @@ mod _ssl { #[pymethod] fn version(&self) -> Option<&'static str> { let v = self.stream.read().ssl().version_str(); - if v == "unknown" { - None - } else { - Some(v) - } + if v == "unknown" { None } else { Some(v) } } #[pymethod] @@ -1019,7 +1054,7 @@ mod _ssl { return Err(socket::timeout_error_msg( vm, "The handshake operation timed out".to_owned(), - )) + )); } SelectRet::Closed => return Err(socket_closed_error(vm)), SelectRet::Nonblocking => {} @@ -1045,7 +1080,7 @@ mod _ssl { return Err(socket::timeout_error_msg( vm, "The write operation timed out".to_owned(), - )) + )); } 
SelectRet::Closed => return Err(socket_closed_error(vm)), _ => {} @@ -1061,7 +1096,7 @@ mod _ssl { return Err(socket::timeout_error_msg( vm, "The write operation timed out".to_owned(), - )) + )); } SelectRet::Closed => return Err(socket_closed_error(vm)), SelectRet::Nonblocking => {} @@ -1113,7 +1148,7 @@ mod _ssl { return Err(socket::timeout_error_msg( vm, "The read operation timed out".to_owned(), - )) + )); } SelectRet::Nonblocking => {} _ => { @@ -1346,13 +1381,13 @@ mod _ssl { #[cfg(target_os = "android")] mod android { use super::convert_openssl_error; - use crate::vm::{builtins::PyBaseExceptionRef, VirtualMachine}; + use crate::vm::{VirtualMachine, builtins::PyBaseExceptionRef}; use openssl::{ ssl::SslContextBuilder, - x509::{store::X509StoreBuilder, X509}, + x509::{X509, store::X509StoreBuilder}, }; use std::{ - fs::{read_dir, File}, + fs::{File, read_dir}, io::Read, path::Path, }; @@ -1426,8 +1461,8 @@ mod windows {} mod ossl101 { #[pyattr] use openssl_sys::{ - SSL_OP_NO_TLSv1_1 as OP_NO_TLSv1_1, SSL_OP_NO_TLSv1_2 as OP_NO_TLSv1_2, - SSL_OP_NO_COMPRESSION as OP_NO_COMPRESSION, + SSL_OP_NO_COMPRESSION as OP_NO_COMPRESSION, SSL_OP_NO_TLSv1_1 as OP_NO_TLSv1_1, + SSL_OP_NO_TLSv1_2 as OP_NO_TLSv1_2, }; } @@ -1445,15 +1480,15 @@ mod windows { use crate::{ common::ascii, vm::{ + PyObjectRef, PyPayload, PyResult, VirtualMachine, builtins::{PyFrozenSet, PyStrRef}, convert::ToPyException, - PyObjectRef, PyPayload, PyResult, VirtualMachine, }, }; #[pyfunction] fn enum_certificates(store_name: PyStrRef, vm: &VirtualMachine) -> PyResult> { - use schannel::{cert_context::ValidUses, cert_store::CertStore, RawPointer}; + use schannel::{RawPointer, cert_context::ValidUses, cert_store::CertStore}; use windows_sys::Win32::Security::Cryptography; // TODO: check every store for it, not just 2 of them: diff --git a/stdlib/src/statistics.rs b/stdlib/src/statistics.rs index 356bfd66f9..72e5d129a0 100644 --- a/stdlib/src/statistics.rs +++ b/stdlib/src/statistics.rs @@ -2,7 +2,7 @@ 
pub(crate) use _statistics::make_module; #[pymodule] mod _statistics { - use crate::vm::{function::ArgIntoFloat, PyResult, VirtualMachine}; + use crate::vm::{PyResult, VirtualMachine, function::ArgIntoFloat}; // See https://github.com/python/cpython/blob/6846d6712a0894f8e1a91716c11dd79f42864216/Modules/_statisticsmodule.c#L28-L120 #[allow(clippy::excessive_precision)] diff --git a/stdlib/src/syslog.rs b/stdlib/src/syslog.rs index 9879f3ffaf..3b36f9ea74 100644 --- a/stdlib/src/syslog.rs +++ b/stdlib/src/syslog.rs @@ -6,10 +6,10 @@ pub(crate) use syslog::make_module; mod syslog { use crate::common::lock::PyRwLock; use crate::vm::{ + PyObjectRef, PyPayload, PyResult, VirtualMachine, builtins::{PyStr, PyStrRef}, function::{OptionalArg, OptionalOption}, utils::ToCString, - PyObjectRef, PyPayload, PyResult, VirtualMachine, }; use std::{ffi::CStr, os::raw::c_char}; @@ -49,7 +49,7 @@ mod syslog { impl GlobalIdent { fn as_ptr(&self) -> *const c_char { match self { - GlobalIdent::Explicit(ref cstr) => cstr.as_ptr(), + GlobalIdent::Explicit(cstr) => cstr.as_ptr(), GlobalIdent::Implicit => std::ptr::null(), } } diff --git a/stdlib/src/termios.rs b/stdlib/src/termios.rs index 84c12b2b68..5c49d62a3c 100644 --- a/stdlib/src/termios.rs +++ b/stdlib/src/termios.rs @@ -3,10 +3,10 @@ pub(crate) use self::termios::make_module; #[pymodule] mod termios { use crate::vm::{ + PyObjectRef, PyResult, TryFromObject, VirtualMachine, builtins::{PyBaseExceptionRef, PyBytes, PyInt, PyListRef, PyTypeRef}, common::os::ErrorExt, convert::ToPyObject, - PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; use termios::Termios; @@ -55,9 +55,9 @@ mod termios { ))] #[pyattr] use libc::{ - FIONCLEX, FIONREAD, TIOCEXCL, TIOCMBIC, TIOCMBIS, TIOCMGET, TIOCMSET, TIOCM_CAR, TIOCM_CD, - TIOCM_CTS, TIOCM_DSR, TIOCM_DTR, TIOCM_LE, TIOCM_RI, TIOCM_RNG, TIOCM_RTS, TIOCM_SR, - TIOCM_ST, TIOCNXCL, TIOCSCTTY, + FIONCLEX, FIONREAD, TIOCEXCL, TIOCM_CAR, TIOCM_CD, TIOCM_CTS, TIOCM_DSR, TIOCM_DTR, + TIOCM_LE, 
TIOCM_RI, TIOCM_RNG, TIOCM_RTS, TIOCM_SR, TIOCM_ST, TIOCMBIC, TIOCMBIS, TIOCMGET, + TIOCMSET, TIOCNXCL, TIOCSCTTY, }; #[cfg(any(target_os = "android", target_os = "linux"))] #[pyattr] @@ -100,12 +100,6 @@ mod termios { ))] #[pyattr] use termios::os::target::TCSASOFT; - #[cfg(any(target_os = "android", target_os = "linux"))] - #[pyattr] - use termios::os::target::{ - B1000000, B1152000, B1500000, B2000000, B2500000, B3000000, B3500000, B4000000, B500000, - B576000, CBAUDEX, - }; #[cfg(any( target_os = "android", target_os = "freebsd", @@ -116,6 +110,12 @@ mod termios { ))] #[pyattr] use termios::os::target::{B460800, B921600}; + #[cfg(any(target_os = "android", target_os = "linux"))] + #[pyattr] + use termios::os::target::{ + B500000, B576000, B1000000, B1152000, B1500000, B2000000, B2500000, B3000000, B3500000, + B4000000, CBAUDEX, + }; #[cfg(any( target_os = "android", target_os = "illumos", @@ -154,16 +154,17 @@ mod termios { use termios::os::target::{VSWTCH, VSWTCH as VSWTC}; #[pyattr] use termios::{ + B0, B50, B75, B110, B134, B150, B200, B300, B600, B1200, B1800, B2400, B4800, B9600, + B19200, B38400, BRKINT, CLOCAL, CREAD, CS5, CS6, CS7, CS8, CSIZE, CSTOPB, ECHO, ECHOE, + ECHOK, ECHONL, HUPCL, ICANON, ICRNL, IEXTEN, IGNBRK, IGNCR, IGNPAR, INLCR, INPCK, ISIG, + ISTRIP, IXANY, IXOFF, IXON, NOFLSH, OCRNL, ONLCR, ONLRET, ONOCR, OPOST, PARENB, PARMRK, + PARODD, TCIFLUSH, TCIOFF, TCIOFLUSH, TCION, TCOFLUSH, TCOOFF, TCOON, TCSADRAIN, TCSAFLUSH, + TCSANOW, TOSTOP, VEOF, VEOL, VERASE, VINTR, VKILL, VMIN, VQUIT, VSTART, VSTOP, VSUSP, + VTIME, os::target::{ - B115200, B230400, B57600, CRTSCTS, ECHOCTL, ECHOKE, ECHOPRT, EXTA, EXTB, FLUSHO, + B57600, B115200, B230400, CRTSCTS, ECHOCTL, ECHOKE, ECHOPRT, EXTA, EXTB, FLUSHO, IMAXBEL, NCCS, PENDIN, VDISCARD, VEOL2, VLNEXT, VREPRINT, VWERASE, }, - B0, B110, B1200, B134, B150, B1800, B19200, B200, B2400, B300, B38400, B4800, B50, B600, - B75, B9600, BRKINT, CLOCAL, CREAD, CS5, CS6, CS7, CS8, CSIZE, CSTOPB, ECHO, ECHOE, ECHOK, - 
ECHONL, HUPCL, ICANON, ICRNL, IEXTEN, IGNBRK, IGNCR, IGNPAR, INLCR, INPCK, ISIG, ISTRIP, - IXANY, IXOFF, IXON, NOFLSH, OCRNL, ONLCR, ONLRET, ONOCR, OPOST, PARENB, PARMRK, PARODD, - TCIFLUSH, TCIOFF, TCIOFLUSH, TCION, TCOFLUSH, TCOOFF, TCOON, TCSADRAIN, TCSAFLUSH, TCSANOW, - TOSTOP, VEOF, VEOL, VERASE, VINTR, VKILL, VMIN, VQUIT, VSTART, VSTOP, VSUSP, VTIME, }; #[pyfunction] diff --git a/stdlib/src/unicodedata.rs b/stdlib/src/unicodedata.rs index 70483073a7..49f3ef6250 100644 --- a/stdlib/src/unicodedata.rs +++ b/stdlib/src/unicodedata.rs @@ -5,8 +5,8 @@ // spell-checker:ignore nfkc unistr unidata use crate::vm::{ - builtins::PyModule, builtins::PyStr, convert::TryFromBorrowedObject, PyObject, PyObjectRef, - PyPayload, PyRef, PyResult, VirtualMachine, + PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::PyModule, + builtins::PyStr, convert::TryFromBorrowedObject, }; pub fn make_module(vm: &VirtualMachine) -> PyRef { @@ -61,14 +61,14 @@ impl<'a> TryFromBorrowedObject<'a> for NormalizeForm { #[pymodule] mod unicodedata { use crate::vm::{ - builtins::PyStrRef, function::OptionalArg, PyObjectRef, PyPayload, PyRef, PyResult, - VirtualMachine, + PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::PyStrRef, + function::OptionalArg, }; use itertools::Itertools; use ucd::{Codepoint, EastAsianWidth}; use unic_char_property::EnumeratedCharProperty; use unic_normal::StrNormalForm; - use unic_ucd_age::{Age, UnicodeVersion, UNICODE_VERSION}; + use unic_ucd_age::{Age, UNICODE_VERSION, UnicodeVersion}; use unic_ucd_bidi::BidiClass; use unic_ucd_category::GeneralCategory; diff --git a/stdlib/src/uuid.rs b/stdlib/src/uuid.rs index e5434da0e1..0bd7d0db8a 100644 --- a/stdlib/src/uuid.rs +++ b/stdlib/src/uuid.rs @@ -5,33 +5,19 @@ mod _uuid { use crate::{builtins::PyNone, vm::VirtualMachine}; use mac_address::get_mac_address; use once_cell::sync::OnceCell; - use rand::Rng; - use std::time::{Duration, SystemTime}; - use uuid::{ - v1::{Context, 
Timestamp}, - Uuid, - }; + use uuid::{Context, Uuid, timestamp::Timestamp}; fn get_node_id() -> [u8; 6] { match get_mac_address() { Ok(Some(_ma)) => get_mac_address().unwrap().unwrap().bytes(), - _ => rand::thread_rng().gen::<[u8; 6]>(), + _ => rand::random::<[u8; 6]>(), } } - pub fn now_unix_duration() -> Duration { - use std::time::UNIX_EPOCH; - - let now = SystemTime::now(); - now.duration_since(UNIX_EPOCH) - .expect("SystemTime before UNIX EPOCH!") - } - #[pyfunction] fn generate_time_safe() -> (Vec, PyNone) { static CONTEXT: Context = Context::new(0); - let now = now_unix_duration(); - let ts = Timestamp::from_unix(&CONTEXT, now.as_secs(), now.subsec_nanos()); + let ts = Timestamp::now(&CONTEXT); static NODE_ID: OnceCell<[u8; 6]> = OnceCell::new(); let unique_node_id = NODE_ID.get_or_init(get_node_id); diff --git a/stdlib/src/zlib.rs b/stdlib/src/zlib.rs index 83a7535c33..19ed659bbb 100644 --- a/stdlib/src/zlib.rs +++ b/stdlib/src/zlib.rs @@ -5,49 +5,37 @@ pub(crate) use zlib::make_module; #[pymodule] mod zlib { use crate::vm::{ - builtins::{PyBaseExceptionRef, PyBytes, PyBytesRef, PyIntRef, PyTypeRef}, + PyObject, PyPayload, PyResult, VirtualMachine, + builtins::{PyBaseExceptionRef, PyBytesRef, PyIntRef, PyTypeRef}, common::lock::PyMutex, convert::TryFromBorrowedObject, function::{ArgBytesLike, ArgPrimitiveIndex, ArgSize, OptionalArg}, - PyObject, PyPayload, PyResult, VirtualMachine, + types::Constructor, }; use adler32::RollingAdler32 as Adler32; - use crossbeam_utils::atomic::AtomicCell; use flate2::{ - write::ZlibEncoder, Compress, Compression, Decompress, FlushCompress, FlushDecompress, - Status, + Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status, + write::ZlibEncoder, }; use std::io::Write; - #[cfg(not(feature = "zlib"))] - mod constants { - pub const Z_NO_COMPRESSION: i32 = 0; - pub const Z_BEST_COMPRESSION: i32 = 9; - pub const Z_BEST_SPEED: i32 = 1; - pub const Z_DEFAULT_COMPRESSION: i32 = -1; - pub const Z_NO_FLUSH: i32 = 0; 
- pub const Z_PARTIAL_FLUSH: i32 = 1; - pub const Z_SYNC_FLUSH: i32 = 2; - pub const Z_FULL_FLUSH: i32 = 3; - // not sure what the value here means, but it's the only compression method zlibmodule - // supports, so it doesn't really matter - pub const Z_DEFLATED: i32 = 8; - } - #[cfg(feature = "zlib")] - use libz_sys as constants; - #[pyattr] - use constants::{ - Z_BEST_COMPRESSION, Z_BEST_SPEED, Z_DEFAULT_COMPRESSION, Z_DEFLATED as DEFLATED, - Z_FULL_FLUSH, Z_NO_COMPRESSION, Z_NO_FLUSH, Z_PARTIAL_FLUSH, Z_SYNC_FLUSH, + use libz_sys::{ + Z_BEST_COMPRESSION, Z_BEST_SPEED, Z_BLOCK, Z_DEFAULT_COMPRESSION, Z_DEFAULT_STRATEGY, + Z_DEFLATED as DEFLATED, Z_FILTERED, Z_FINISH, Z_FIXED, Z_FULL_FLUSH, Z_HUFFMAN_ONLY, + Z_NO_COMPRESSION, Z_NO_FLUSH, Z_PARTIAL_FLUSH, Z_RLE, Z_SYNC_FLUSH, Z_TREES, }; - #[cfg(feature = "zlib")] + // we're statically linking libz-rs, so the compile-time and runtime + // versions will always be the same + #[pyattr(name = "ZLIB_RUNTIME_VERSION")] #[pyattr] - use libz_sys::{ - Z_BLOCK, Z_DEFAULT_STRATEGY, Z_FILTERED, Z_FINISH, Z_FIXED, Z_HUFFMAN_ONLY, Z_RLE, Z_TREES, + const ZLIB_VERSION: &str = unsafe { + match std::ffi::CStr::from_ptr(libz_sys::zlibVersion()).to_str() { + Ok(s) => s, + Err(_) => unreachable!(), + } }; - use rustpython_vm::types::Constructor; // copied from zlibmodule.c (commit 530f506ac91338) #[pyattr] @@ -102,15 +90,10 @@ mod zlib { } = args; let level = level.ok_or_else(|| new_zlib_error("Bad compression level", vm))?; - let encoded_bytes = if args.wbits.value == MAX_WBITS { - let mut encoder = ZlibEncoder::new(Vec::new(), level); - data.with_ref(|input_bytes| encoder.write_all(input_bytes).unwrap()); - encoder.finish().unwrap() - } else { - let mut inner = CompressInner::new(InitOptions::new(wbits.value, vm)?.compress(level)); - data.with_ref(|input_bytes| inner.compress(input_bytes, vm))?; - inner.flush(vm)? 
- }; + let compress = InitOptions::new(wbits.value, vm)?.compress(level); + let mut encoder = ZlibEncoder::new_with_compress(Vec::new(), compress); + data.with_ref(|input_bytes| encoder.write_all(input_bytes).unwrap()); + let encoded_bytes = encoder.finish().unwrap(); Ok(vm.ctx.new_bytes(encoded_bytes)) } @@ -119,11 +102,11 @@ mod zlib { header: bool, // [De]Compress::new_with_window_bits is only enabled for zlib; miniz_oxide doesn't // support wbits (yet?) - #[cfg(feature = "zlib")] wbits: u8, }, - #[cfg(feature = "zlib")] - Gzip { wbits: u8 }, + Gzip { + wbits: u8, + }, } impl InitOptions { @@ -131,12 +114,12 @@ mod zlib { let header = wbits > 0; let wbits = wbits.unsigned_abs(); match wbits { - 9..=15 => Ok(InitOptions::Standard { - header, - #[cfg(feature = "zlib")] - wbits, - }), - #[cfg(feature = "zlib")] + // TODO: wbits = 0 should be a valid option: + // > windowBits can also be zero to request that inflate use the window size in + // > the zlib header of the compressed stream. + // but flate2 doesn't expose it + // 0 => ... 
+ 9..=15 => Ok(InitOptions::Standard { header, wbits }), 25..=31 => Ok(InitOptions::Gzip { wbits: wbits - 16 }), _ => Err(vm.new_value_error("Invalid initialization option".to_owned())), } @@ -144,51 +127,95 @@ mod zlib { fn decompress(self) -> Decompress { match self { - #[cfg(not(feature = "zlib"))] - Self::Standard { header } => Decompress::new(header), - #[cfg(feature = "zlib")] Self::Standard { header, wbits } => Decompress::new_with_window_bits(header, wbits), - #[cfg(feature = "zlib")] Self::Gzip { wbits } => Decompress::new_gzip(wbits), } } fn compress(self, level: Compression) -> Compress { match self { - #[cfg(not(feature = "zlib"))] - Self::Standard { header } => Compress::new(level, header), - #[cfg(feature = "zlib")] Self::Standard { header, wbits } => { Compress::new_with_window_bits(level, header, wbits) } - #[cfg(feature = "zlib")] Self::Gzip { wbits } => Compress::new_gzip(level, wbits), } } } + #[derive(Clone)] + struct Chunker<'a> { + data1: &'a [u8], + data2: &'a [u8], + } + impl<'a> Chunker<'a> { + fn new(data: &'a [u8]) -> Self { + Self { + data1: data, + data2: &[], + } + } + fn chain(data1: &'a [u8], data2: &'a [u8]) -> Self { + if data1.is_empty() { + Self { + data1: data2, + data2: &[], + } + } else { + Self { data1, data2 } + } + } + fn len(&self) -> usize { + self.data1.len() + self.data2.len() + } + fn is_empty(&self) -> bool { + self.data1.is_empty() + } + fn to_vec(&self) -> Vec { + [self.data1, self.data2].concat() + } + fn chunk(&self) -> &'a [u8] { + self.data1.get(..CHUNKSIZE).unwrap_or(self.data1) + } + fn advance(&mut self, consumed: usize) { + self.data1 = &self.data1[consumed..]; + if self.data1.is_empty() { + self.data1 = std::mem::take(&mut self.data2); + } + } + } + fn _decompress( - mut data: &[u8], + data: &[u8], + d: &mut Decompress, + bufsize: usize, + max_length: Option, + is_flush: bool, + zdict: Option<&ArgBytesLike>, + vm: &VirtualMachine, + ) -> PyResult<(Vec, bool)> { + let mut data = Chunker::new(data); + 
_decompress_chunks(&mut data, d, bufsize, max_length, is_flush, zdict, vm) + } + + fn _decompress_chunks( + data: &mut Chunker<'_>, d: &mut Decompress, bufsize: usize, max_length: Option, is_flush: bool, + zdict: Option<&ArgBytesLike>, vm: &VirtualMachine, ) -> PyResult<(Vec, bool)> { if data.is_empty() { return Ok((Vec::new(), true)); } + let max_length = max_length.unwrap_or(usize::MAX); let mut buf = Vec::new(); - loop { - let final_chunk = data.len() <= CHUNKSIZE; - let chunk = if final_chunk { - data - } else { - &data[..CHUNKSIZE] - }; - // if this is the final chunk, finish it + 'outer: loop { + let chunk = data.chunk(); let flush = if is_flush { - if final_chunk { + // if this is the final chunk, finish it + if chunk.len() == data.len() { FlushDecompress::Finish } else { FlushDecompress::None @@ -197,34 +224,43 @@ mod zlib { FlushDecompress::Sync }; loop { - let additional = if let Some(max_length) = max_length { - std::cmp::min(bufsize, max_length - buf.capacity()) - } else { - bufsize - }; + let additional = std::cmp::min(bufsize, max_length - buf.capacity()); if additional == 0 { return Ok((buf, false)); } - buf.reserve_exact(additional); + let prev_in = d.total_in(); - let status = d - .decompress_vec(chunk, &mut buf, flush) - .map_err(|_| new_zlib_error("invalid input data", vm))?; + let res = d.decompress_vec(chunk, &mut buf, flush); let consumed = d.total_in() - prev_in; - data = &data[consumed as usize..]; - let stream_end = status == Status::StreamEnd; - if stream_end || data.is_empty() { - // we've reached the end of the stream, we're done - buf.shrink_to_fit(); - return Ok((buf, stream_end)); - } else if !chunk.is_empty() && consumed == 0 { - // we're gonna need a bigger buffer - continue; - } else { - // next chunk - break; - } + + data.advance(consumed as usize); + + match res { + Ok(status) => { + let stream_end = status == Status::StreamEnd; + if stream_end || data.is_empty() { + // we've reached the end of the stream, we're done + 
buf.shrink_to_fit(); + return Ok((buf, stream_end)); + } else if !chunk.is_empty() && consumed == 0 { + // we're gonna need a bigger buffer + continue; + } else { + // next chunk + continue 'outer; + } + } + Err(e) => { + let Some(zdict) = e.needs_dictionary().and(zdict) else { + return Err(new_zlib_error(&e.to_string(), vm)); + }; + d.set_dictionary(&zdict.borrow_buf()) + .map_err(|_| new_zlib_error("failed to set dictionary", vm))?; + // now try the next chunk + continue 'outer; + } + }; } } } @@ -249,7 +285,8 @@ mod zlib { } = args; data.with_ref(|data| { let mut d = InitOptions::new(wbits.value, vm)?.decompress(); - let (buf, stream_end) = _decompress(data, &mut d, bufsize.value, None, false, vm)?; + let (buf, stream_end) = + _decompress(data, &mut d, bufsize.value, None, false, None, vm)?; if !stream_end { return Err(new_zlib_error( "Error -5 while decompressing data: incomplete or truncated stream", @@ -264,104 +301,117 @@ mod zlib { struct DecompressobjArgs { #[pyarg(any, default = "ArgPrimitiveIndex { value: MAX_WBITS }")] wbits: ArgPrimitiveIndex, - #[cfg(feature = "zlib")] #[pyarg(any, optional)] - _zdict: OptionalArg, + zdict: OptionalArg, } #[pyfunction] fn decompressobj(args: DecompressobjArgs, vm: &VirtualMachine) -> PyResult { - #[allow(unused_mut)] let mut decompress = InitOptions::new(args.wbits.value, vm)?.decompress(); - #[cfg(feature = "zlib")] - if let OptionalArg::Present(_dict) = args._zdict { - // FIXME: always fails - // dict.with_ref(|d| decompress.set_dictionary(d)); + let zdict = args.zdict.into_option(); + if let Some(dict) = &zdict { + if args.wbits.value < 0 { + dict.with_ref(|d| decompress.set_dictionary(d)) + .map_err(|_| new_zlib_error("failed to set dictionary", vm))?; + } } + let inner = PyDecompressInner { + decompress: Some(decompress), + eof: false, + zdict, + unused_data: vm.ctx.empty_bytes.clone(), + unconsumed_tail: vm.ctx.empty_bytes.clone(), + }; Ok(PyDecompress { - decompress: PyMutex::new(decompress), - eof: 
AtomicCell::new(false), - unused_data: PyMutex::new(PyBytes::from(vec![]).into_ref(&vm.ctx)), - unconsumed_tail: PyMutex::new(PyBytes::from(vec![]).into_ref(&vm.ctx)), + inner: PyMutex::new(inner), }) } + + #[derive(Debug)] + struct PyDecompressInner { + decompress: Option, + zdict: Option, + eof: bool, + unused_data: PyBytesRef, + unconsumed_tail: PyBytesRef, + } + #[pyattr] #[pyclass(name = "Decompress")] #[derive(Debug, PyPayload)] struct PyDecompress { - decompress: PyMutex, - eof: AtomicCell, - unused_data: PyMutex, - unconsumed_tail: PyMutex, + inner: PyMutex, } + #[pyclass] impl PyDecompress { #[pygetset] fn eof(&self) -> bool { - self.eof.load() + self.inner.lock().eof } #[pygetset] fn unused_data(&self) -> PyBytesRef { - self.unused_data.lock().clone() + self.inner.lock().unused_data.clone() } #[pygetset] fn unconsumed_tail(&self) -> PyBytesRef { - self.unconsumed_tail.lock().clone() + self.inner.lock().unconsumed_tail.clone() } - fn save_unused_input( - &self, - d: &Decompress, + fn decompress_inner( + inner: &mut PyDecompressInner, data: &[u8], - stream_end: bool, - orig_in: u64, + bufsize: usize, + max_length: Option, + is_flush: bool, vm: &VirtualMachine, - ) { - let leftover = &data[(d.total_in() - orig_in) as usize..]; - - if stream_end && !leftover.is_empty() { - let mut unused_data = self.unused_data.lock(); - let unused: Vec<_> = unused_data - .as_bytes() - .iter() - .chain(leftover) - .copied() - .collect(); - *unused_data = vm.ctx.new_pyref(unused); + ) -> PyResult<(PyResult>, bool)> { + let Some(d) = &mut inner.decompress else { + return Err(new_zlib_error(USE_AFTER_FINISH_ERR, vm)); + }; + + let zdict = if is_flush { None } else { inner.zdict.as_ref() }; + + let prev_in = d.total_in(); + let (ret, stream_end) = + match _decompress(data, d, bufsize, max_length, is_flush, zdict, vm) { + Ok((buf, stream_end)) => (Ok(buf), stream_end), + Err(err) => (Err(err), false), + }; + let consumed = (d.total_in() - prev_in) as usize; + + // save unused 
input + let unconsumed = &data[consumed..]; + if !unconsumed.is_empty() { + if stream_end { + let unused = [inner.unused_data.as_bytes(), unconsumed].concat(); + inner.unused_data = vm.ctx.new_pyref(unused); + } else { + inner.unconsumed_tail = vm.ctx.new_bytes(unconsumed.to_vec()); + } + } else if !inner.unconsumed_tail.is_empty() { + inner.unconsumed_tail = vm.ctx.empty_bytes.clone(); } + + Ok((ret, stream_end)) } #[pymethod] fn decompress(&self, args: DecompressArgs, vm: &VirtualMachine) -> PyResult> { - let max_length = args.max_length.value; + let max_length: usize = args + .max_length + .map_or(0, |x| x.value) + .try_into() + .map_err(|_| vm.new_value_error("must be non-negative".to_owned()))?; let max_length = (max_length != 0).then_some(max_length); - let data = args.data.borrow_buf(); - let data = &*data; + let data = &*args.data.borrow_buf(); - let mut d = self.decompress.lock(); - let orig_in = d.total_in(); + let inner = &mut *self.inner.lock(); let (ret, stream_end) = - match _decompress(data, &mut d, DEF_BUF_SIZE, max_length, false, vm) { - Ok((buf, true)) => { - self.eof.store(true); - (Ok(buf), true) - } - Ok((buf, false)) => (Ok(buf), false), - Err(err) => (Err(err), false), - }; - self.save_unused_input(&d, data, stream_end, orig_in, vm); - - let leftover = if stream_end { - b"" - } else { - &data[(d.total_in() - orig_in) as usize..] 
- }; + Self::decompress_inner(inner, data, DEF_BUF_SIZE, max_length, false, vm)?; - let mut unconsumed_tail = self.unconsumed_tail.lock(); - if !leftover.is_empty() || !unconsumed_tail.is_empty() { - *unconsumed_tail = PyBytes::from(leftover.to_owned()).into_ref(&vm.ctx); - } + inner.eof |= stream_end; ret } @@ -369,36 +419,22 @@ mod zlib { #[pymethod] fn flush(&self, length: OptionalArg, vm: &VirtualMachine) -> PyResult> { let length = match length { - OptionalArg::Present(l) => { - let l: isize = l.into(); - if l <= 0 { - return Err( - vm.new_value_error("length must be greater than zero".to_owned()) - ); - } else { - l as usize - } + OptionalArg::Present(ArgSize { value }) if value <= 0 => { + return Err(vm.new_value_error("length must be greater than zero".to_owned())); } + OptionalArg::Present(ArgSize { value }) => value as usize, OptionalArg::Missing => DEF_BUF_SIZE, }; - let mut data = self.unconsumed_tail.lock(); - let mut d = self.decompress.lock(); + let inner = &mut *self.inner.lock(); + let data = std::mem::replace(&mut inner.unconsumed_tail, vm.ctx.empty_bytes.clone()); - let orig_in = d.total_in(); + let (ret, _) = Self::decompress_inner(inner, &data, length, None, true, vm)?; - let (ret, stream_end) = match _decompress(&data, &mut d, length, None, true, vm) { - Ok((buf, stream_end)) => (Ok(buf), stream_end), - Err(err) => (Err(err), false), - }; - self.save_unused_input(&d, &data, stream_end, orig_in, vm); - - *data = PyBytes::from(Vec::new()).into_ref(&vm.ctx); + if inner.eof { + inner.decompress = None; + } - // TODO: drop the inner decompressor, somehow - // if stream_end { - // - // } ret } } @@ -407,11 +443,8 @@ mod zlib { struct DecompressArgs { #[pyarg(positional)] data: ArgBytesLike, - #[pyarg( - any, - default = "rustpython_vm::function::ArgPrimitiveIndex { value: 0 }" - )] - max_length: ArgPrimitiveIndex, + #[pyarg(any, optional)] + max_length: OptionalArg, } #[derive(FromArgs)] @@ -421,15 +454,13 @@ mod zlib { level: Level, // only 
DEFLATED is valid right now, it's w/e #[pyarg(any, default = "DEFLATED")] - _method: i32, + method: i32, #[pyarg(any, default = "ArgPrimitiveIndex { value: MAX_WBITS }")] wbits: ArgPrimitiveIndex, - #[pyarg(any, name = "_memLevel", default = "DEF_MEM_LEVEL")] - _mem_level: u8, - #[cfg(feature = "zlib")] + #[pyarg(any, name = "memLevel", default = "DEF_MEM_LEVEL")] + mem_level: u8, #[pyarg(any, default = "Z_DEFAULT_STRATEGY")] - _strategy: i32, - #[cfg(feature = "zlib")] + strategy: i32, #[pyarg(any, optional)] zdict: Option, } @@ -439,7 +470,6 @@ mod zlib { let CompressobjArgs { level, wbits, - #[cfg(feature = "zlib")] zdict, .. } = args; @@ -447,7 +477,6 @@ mod zlib { level.ok_or_else(|| vm.new_value_error("invalid initialization option".to_owned()))?; #[allow(unused_mut)] let mut compress = InitOptions::new(wbits.value, vm)?.compress(level); - #[cfg(feature = "zlib")] if let Some(zdict) = zdict { zdict.with_ref(|zdict| compress.set_dictionary(zdict).unwrap()); } @@ -458,8 +487,7 @@ mod zlib { #[derive(Debug)] struct CompressInner { - compress: Compress, - unconsumed: Vec, + compress: Option, } #[pyattr] @@ -477,10 +505,17 @@ mod zlib { data.with_ref(|b| inner.compress(b, vm)) } - // TODO: mode argument isn't used #[pymethod] - fn flush(&self, _mode: OptionalArg, vm: &VirtualMachine) -> PyResult> { - self.inner.lock().flush(vm) + fn flush(&self, mode: OptionalArg, vm: &VirtualMachine) -> PyResult> { + let mode = match mode.unwrap_or(Z_FINISH) { + Z_NO_FLUSH => return Ok(vec![]), + Z_PARTIAL_FLUSH => FlushCompress::Partial, + Z_SYNC_FLUSH => FlushCompress::Sync, + Z_FULL_FLUSH => FlushCompress::Full, + Z_FINISH => FlushCompress::Finish, + _ => return Err(new_zlib_error("invalid mode", vm)), + }; + self.inner.lock().flush(mode, vm) } // TODO: This is an optional feature of Compress @@ -497,59 +532,55 @@ mod zlib { impl CompressInner { fn new(compress: Compress) -> Self { Self { - compress, - unconsumed: Vec::new(), + compress: Some(compress), } } + + fn 
get_compress(&mut self, vm: &VirtualMachine) -> PyResult<&mut Compress> { + self.compress + .as_mut() + .ok_or_else(|| new_zlib_error(USE_AFTER_FINISH_ERR, vm)) + } + fn compress(&mut self, data: &[u8], vm: &VirtualMachine) -> PyResult> { - let orig_in = self.compress.total_in() as usize; - let mut cur_in = 0; - let unconsumed = std::mem::take(&mut self.unconsumed); + let c = self.get_compress(vm)?; let mut buf = Vec::new(); - 'outer: for chunk in unconsumed.chunks(CHUNKSIZE).chain(data.chunks(CHUNKSIZE)) { - while cur_in < chunk.len() { + for mut chunk in data.chunks(CHUNKSIZE) { + while !chunk.is_empty() { buf.reserve(DEF_BUF_SIZE); - let status = self - .compress - .compress_vec(&chunk[cur_in..], &mut buf, FlushCompress::None) - .map_err(|_| { - self.unconsumed.extend_from_slice(&data[cur_in..]); - new_zlib_error("error while compressing", vm) - })?; - cur_in = (self.compress.total_in() as usize) - orig_in; - match status { - Status::Ok => continue, - Status::StreamEnd => break 'outer, - _ => break, - } + let prev_in = c.total_in(); + c.compress_vec(chunk, &mut buf, FlushCompress::None) + .map_err(|_| new_zlib_error("error while compressing", vm))?; + let consumed = c.total_in() - prev_in; + chunk = &chunk[consumed as usize..]; } } - self.unconsumed.extend_from_slice(&data[cur_in..]); buf.shrink_to_fit(); Ok(buf) } - // TODO: flush mode (FlushDecompress) parameter - fn flush(&mut self, vm: &VirtualMachine) -> PyResult> { - let data = std::mem::take(&mut self.unconsumed); - let mut data_it = data.chunks(CHUNKSIZE); + fn flush(&mut self, mode: FlushCompress, vm: &VirtualMachine) -> PyResult> { + let c = self.get_compress(vm)?; let mut buf = Vec::new(); - loop { - let chunk = data_it.next().unwrap_or(&[]); + let status = loop { if buf.len() == buf.capacity() { buf.reserve(DEF_BUF_SIZE); } - let status = self - .compress - .compress_vec(chunk, &mut buf, FlushCompress::Finish) + let status = c + .compress_vec(&[], &mut buf, mode) .map_err(|_| new_zlib_error("error 
while compressing", vm))?; - match status { - Status::StreamEnd => break, - _ => continue, + if buf.len() != buf.capacity() { + break status; } + }; + + match status { + Status::Ok | Status::BufError => {} + Status::StreamEnd if mode == FlushCompress::Finish => self.compress = None, + Status::StreamEnd => return Err(new_zlib_error("unexpected eof", vm)), } buf.shrink_to_fit(); @@ -561,6 +592,8 @@ mod zlib { vm.new_exception_msg(vm.class("zlib", "error"), message.to_owned()) } + const USE_AFTER_FINISH_ERR: &str = "Error -2: inconsistent stream state"; + struct Level(Option); impl Level { @@ -592,130 +625,116 @@ mod zlib { #[pyattr] #[pyclass(name = "_ZlibDecompressor")] #[derive(Debug, PyPayload)] - pub struct ZlibDecompressor { - decompress: PyMutex, - unused_data: PyMutex, - unconsumed_tail: PyMutex, + struct ZlibDecompressor { + inner: PyMutex, + } + + #[derive(Debug)] + struct ZlibDecompressorInner { + decompress: Decompress, + unused_data: PyBytesRef, + input_buffer: Vec, + zdict: Option, + eof: bool, + needs_input: bool, } impl Constructor for ZlibDecompressor { - type Args = (); - - fn py_new(cls: PyTypeRef, _args: Self::Args, vm: &VirtualMachine) -> PyResult { - let decompress = Decompress::new(true); - let zlib_decompressor = Self { - decompress: PyMutex::new(decompress), - unused_data: PyMutex::new(PyBytes::from(vec![]).into_ref(&vm.ctx)), - unconsumed_tail: PyMutex::new(PyBytes::from(vec![]).into_ref(&vm.ctx)), + type Args = DecompressobjArgs; + + fn py_new(cls: PyTypeRef, args: Self::Args, vm: &VirtualMachine) -> PyResult { + let mut decompress = InitOptions::new(args.wbits.value, vm)?.decompress(); + let zdict = args.zdict.into_option(); + if let Some(dict) = &zdict { + if args.wbits.value < 0 { + dict.with_ref(|d| decompress.set_dictionary(d)) + .map_err(|_| new_zlib_error("failed to set dictionary", vm))?; + } + } + let inner = ZlibDecompressorInner { + decompress, + unused_data: vm.ctx.empty_bytes.clone(), + input_buffer: Vec::new(), + zdict, + eof: 
false, + needs_input: true, }; - zlib_decompressor - .into_ref_with_type(vm, cls) - .map(Into::into) + Self { + inner: PyMutex::new(inner), + } + .into_ref_with_type(vm, cls) + .map(Into::into) } } #[pyclass(with(Constructor))] impl ZlibDecompressor { #[pygetset] - fn unused_data(&self) -> PyBytesRef { - self.unused_data.lock().clone() + fn eof(&self) -> bool { + self.inner.lock().eof } #[pygetset] - fn unconsumed_tail(&self) -> PyBytesRef { - self.unconsumed_tail.lock().clone() + fn unused_data(&self) -> PyBytesRef { + self.inner.lock().unused_data.clone() } - fn save_unused_input( - &self, - d: &Decompress, - data: &[u8], - stream_end: bool, - orig_in: u64, - vm: &VirtualMachine, - ) { - let leftover = &data[(d.total_in() - orig_in) as usize..]; - - if stream_end && !leftover.is_empty() { - let mut unused_data = self.unused_data.lock(); - let unused: Vec<_> = unused_data - .as_bytes() - .iter() - .chain(leftover) - .copied() - .collect(); - *unused_data = vm.ctx.new_pyref(unused); - } + #[pygetset] + fn needs_input(&self) -> bool { + self.inner.lock().needs_input } #[pymethod] - fn decompress(&self, args: PyBytesRef, vm: &VirtualMachine) -> PyResult> { - // let max_length = args.max_length.value; - // let max_length = (max_length != 0).then_some(max_length); - let max_length = None; - let data = args.as_bytes(); - - let mut d = self.decompress.lock(); - let orig_in = d.total_in(); - - let (ret, stream_end) = - match _decompress(data, &mut d, DEF_BUF_SIZE, max_length, false, vm) { - Ok((buf, true)) => { - // Eof is true - (Ok(buf), true) - } - Ok((buf, false)) => (Ok(buf), false), - Err(err) => (Err(err), false), - }; - self.save_unused_input(&d, data, stream_end, orig_in, vm); + fn decompress(&self, args: DecompressArgs, vm: &VirtualMachine) -> PyResult> { + let max_length = args + .max_length + .into_option() + .and_then(|ArgSize { value }| usize::try_from(value).ok()); + let data = &*args.data.borrow_buf(); - let leftover = if stream_end { - b"" - } else { - 
&data[(d.total_in() - orig_in) as usize..] - }; + let inner = &mut *self.inner.lock(); - let mut unconsumed_tail = self.unconsumed_tail.lock(); - if !leftover.is_empty() || !unconsumed_tail.is_empty() { - *unconsumed_tail = PyBytes::from(leftover.to_owned()).into_ref(&vm.ctx); + if inner.eof { + return Err(vm.new_eof_error("End of stream already reached".to_owned())); } - ret - } + let input_buffer = &mut inner.input_buffer; + let d = &mut inner.decompress; - #[pymethod] - fn flush(&self, length: OptionalArg, vm: &VirtualMachine) -> PyResult> { - let length = match length { - OptionalArg::Present(l) => { - let l: isize = l.into(); - if l <= 0 { - return Err( - vm.new_value_error("length must be greater than zero".to_owned()) - ); - } else { - l as usize - } - } - OptionalArg::Missing => DEF_BUF_SIZE, - }; + let mut chunks = Chunker::chain(input_buffer, data); - let mut data = self.unconsumed_tail.lock(); - let mut d = self.decompress.lock(); + let zdict = inner.zdict.as_ref(); + let bufsize = DEF_BUF_SIZE; - let orig_in = d.total_in(); + let prev_len = chunks.len(); + let (ret, stream_end) = + match _decompress_chunks(&mut chunks, d, bufsize, max_length, false, zdict, vm) { + Ok((buf, stream_end)) => (Ok(buf), stream_end), + Err(err) => (Err(err), false), + }; + let consumed = prev_len - chunks.len(); - let (ret, stream_end) = match _decompress(&data, &mut d, length, None, true, vm) { - Ok((buf, stream_end)) => (Ok(buf), stream_end), - Err(err) => (Err(err), false), - }; - self.save_unused_input(&d, &data, stream_end, orig_in, vm); + inner.eof |= stream_end; - *data = PyBytes::from(Vec::new()).into_ref(&vm.ctx); + if inner.eof { + inner.needs_input = false; + if !chunks.is_empty() { + inner.unused_data = vm.ctx.new_bytes(chunks.to_vec()); + } + } else if chunks.is_empty() { + input_buffer.clear(); + inner.needs_input = true; + } else { + inner.needs_input = false; + if let Some(n_consumed_from_data) = consumed.checked_sub(input_buffer.len()) { + 
input_buffer.clear(); + input_buffer.extend_from_slice(&data[n_consumed_from_data..]); + } else { + input_buffer.drain(..consumed); + input_buffer.extend_from_slice(data); + } + } - // TODO: drop the inner decompressor, somehow - // if stream_end { - // - // } ret } diff --git a/vm/Cargo.toml b/vm/Cargo.toml index 6eaca281f8..acc645bb74 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -23,7 +23,7 @@ ast = ["rustpython-ast"] codegen = ["rustpython-codegen", "ast"] parser = ["rustpython-parser", "ast"] serde = ["dep:serde"] -wasmbind = ["chrono/wasmbind", "getrandom/js", "wasm-bindgen"] +wasmbind = ["chrono/wasmbind", "getrandom/wasm_js", "wasm-bindgen"] [dependencies] rustpython-compiler = { workspace = true, optional = true } @@ -99,6 +99,9 @@ uname = "0.1.1" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] rustyline = { workspace = true } which = "6" +errno = "0.3" +libloading = "0.8" +widestring = { workspace = true } [target.'cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))'.dependencies] num_cpus = "1.13.1" @@ -106,8 +109,7 @@ num_cpus = "1.13.1" [target.'cfg(windows)'.dependencies] junction = { workspace = true } schannel = { workspace = true } -widestring = { workspace = true } -winreg = "0.52" +winreg = "0.55" [target.'cfg(windows)'.dependencies.windows] version = "0.52.0" @@ -115,6 +117,7 @@ features = [ "Win32_Foundation", "Win32_System_LibraryLoader", "Win32_System_Threading", + "Win32_System_Time", "Win32_UI_Shell", ] @@ -122,6 +125,7 @@ features = [ workspace = true features = [ "Win32_Foundation", + "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", @@ -143,7 +147,7 @@ features = [ [target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dependencies] wasm-bindgen = { workspace = true, optional = true } -getrandom = { workspace = true, features = ["custom"] } +getrandom = { workspace = true } [build-dependencies] glob = { workspace = true } diff --git a/vm/src/anystr.rs 
b/vm/src/anystr.rs index f5bfab1c57..0fd1d8f2f6 100644 --- a/vm/src/anystr.rs +++ b/vm/src/anystr.rs @@ -1,9 +1,9 @@ use crate::{ + Py, PyObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine, builtins::{PyIntRef, PyTuple}, cformat::cformat_string, convert::TryFromBorrowedObject, function::OptionalOption, - Py, PyObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; use num_traits::{cast::ToPrimitive, sign::Signed}; @@ -386,11 +386,7 @@ pub trait AnyStr { b'\n' => (keep, 1), b'\r' => { let is_rn = enumerated.next_if(|(_, ch)| **ch == b'\n').is_some(); - if is_rn { - (keep + keep, 2) - } else { - (keep, 1) - } + if is_rn { (keep + keep, 2) } else { (keep, 1) } } _ => continue, }; diff --git a/vm/src/buffer.rs b/vm/src/buffer.rs index bebc7b13a9..3b76002d04 100644 --- a/vm/src/buffer.rs +++ b/vm/src/buffer.rs @@ -1,9 +1,9 @@ use crate::{ + PyObjectRef, PyResult, TryFromObject, VirtualMachine, builtins::{PyBaseExceptionRef, PyBytesRef, PyTuple, PyTupleRef, PyTypeRef}, common::{static_cell, str::wchar_t}, convert::ToPyObject, function::{ArgBytesLike, ArgIntoBool, ArgIntoFloat}, - PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; use half::f16; use itertools::Itertools; @@ -99,8 +99,8 @@ impl fmt::Debug for FormatType { impl FormatType { fn info(self, e: Endianness) -> &'static FormatInfo { - use mem::{align_of, size_of}; use FormatType::*; + use mem::{align_of, size_of}; macro_rules! 
native_info { ($t:ty) => {{ &FormatInfo { diff --git a/vm/src/builtins/asyncgenerator.rs b/vm/src/builtins/asyncgenerator.rs index f9ed4719d8..3aee327e5b 100644 --- a/vm/src/builtins/asyncgenerator.rs +++ b/vm/src/builtins/asyncgenerator.rs @@ -1,5 +1,6 @@ use super::{PyCode, PyGenericAlias, PyStrRef, PyType, PyTypeRef}; use crate::{ + AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::PyBaseExceptionRef, class::PyClassImpl, coroutine::Coro, @@ -7,7 +8,6 @@ use crate::{ function::OptionalArg, protocol::PyIterReturn, types::{IterNext, Iterable, Representable, SelfIter, Unconstructible}, - AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; @@ -208,7 +208,7 @@ impl PyAsyncGenASend { AwaitableState::Closed => { return Err(vm.new_runtime_error( "cannot reuse already awaited __anext__()/asend()".to_owned(), - )) + )); } AwaitableState::Iter => val, // already running, all good AwaitableState::Init => { diff --git a/vm/src/builtins/bool.rs b/vm/src/builtins/bool.rs index 63bf6cff2d..93f5f4f1de 100644 --- a/vm/src/builtins/bool.rs +++ b/vm/src/builtins/bool.rs @@ -1,13 +1,13 @@ use super::{PyInt, PyStrRef, PyType, PyTypeRef}; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, TryFromBorrowedObject, + VirtualMachine, class::PyClassImpl, convert::{IntoPyException, ToPyObject, ToPyResult}, function::OptionalArg, identifier, protocol::PyNumberMethods, types::{AsNumber, Constructor, Representable}, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, TryFromBorrowedObject, - VirtualMachine, }; use malachite_bigint::Sign; use num_traits::Zero; diff --git a/vm/src/builtins/builtin_func.rs b/vm/src/builtins/builtin_func.rs index bc1b082bc9..ff3ef38d3a 100644 --- a/vm/src/builtins/builtin_func.rs +++ b/vm/src/builtins/builtin_func.rs @@ -1,10 +1,10 @@ -use super::{type_, PyStrInterned, PyStrRef, PyType}; +use 
super::{PyStrInterned, PyStrRef, PyType, type_}; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, class::PyClassImpl, convert::TryFromObject, function::{FuncArgs, PyComparisonValue, PyMethodDef, PyMethodFlags, PyNativeFn}, types::{Callable, Comparable, PyComparisonOp, Representable, Unconstructible}, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; use std::fmt; diff --git a/vm/src/builtins/bytearray.rs b/vm/src/builtins/bytearray.rs index 40e6cf7b5c..36cf8cadcd 100644 --- a/vm/src/builtins/bytearray.rs +++ b/vm/src/builtins/bytearray.rs @@ -4,12 +4,14 @@ use super::{ PyType, PyTypeRef, }; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, + VirtualMachine, anystr::{self, AnyStr}, atomic_func, byte::{bytes_from_object, value_from_object}, bytesinner::{ - bytes_decode, ByteInnerFindOptions, ByteInnerNewOptions, ByteInnerPaddingOptions, - ByteInnerSplitOptions, ByteInnerTranslateOptions, DecodeArgs, PyBytesInner, + ByteInnerFindOptions, ByteInnerNewOptions, ByteInnerPaddingOptions, ByteInnerSplitOptions, + ByteInnerTranslateOptions, DecodeArgs, PyBytesInner, bytes_decode, }, class::PyClassImpl, common::{ @@ -33,8 +35,6 @@ use crate::{ DefaultConstructor, Initializer, IterNext, Iterable, PyComparisonOp, Representable, SelfIter, Unconstructible, }, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, - VirtualMachine, }; use bstr::ByteSlice; use std::mem::size_of; diff --git a/vm/src/builtins/bytes.rs b/vm/src/builtins/bytes.rs index 1db6417836..e9f5adc8bb 100644 --- a/vm/src/builtins/bytes.rs +++ b/vm/src/builtins/bytes.rs @@ -2,11 +2,13 @@ use super::{ PositionIterInternal, PyDictRef, PyIntRef, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, }; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, + TryFromBorrowedObject, TryFromObject, 
VirtualMachine, anystr::{self, AnyStr}, atomic_func, bytesinner::{ - bytes_decode, ByteInnerFindOptions, ByteInnerNewOptions, ByteInnerPaddingOptions, - ByteInnerSplitOptions, ByteInnerTranslateOptions, DecodeArgs, PyBytesInner, + ByteInnerFindOptions, ByteInnerNewOptions, ByteInnerPaddingOptions, ByteInnerSplitOptions, + ByteInnerTranslateOptions, DecodeArgs, PyBytesInner, bytes_decode, }, class::PyClassImpl, common::{hash::PyHash, lock::PyMutex}, @@ -23,8 +25,6 @@ use crate::{ AsBuffer, AsMapping, AsNumber, AsSequence, Callable, Comparable, Constructor, Hashable, IterNext, Iterable, PyComparisonOp, Representable, SelfIter, Unconstructible, }, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, - TryFromBorrowedObject, TryFromObject, VirtualMachine, }; use bstr::ByteSlice; use once_cell::sync::Lazy; diff --git a/vm/src/builtins/classmethod.rs b/vm/src/builtins/classmethod.rs index 02f836199e..94b6e0ca9f 100644 --- a/vm/src/builtins/classmethod.rs +++ b/vm/src/builtins/classmethod.rs @@ -1,9 +1,9 @@ use super::{PyBoundMethod, PyStr, PyType, PyTypeRef}; use crate::{ + AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, class::PyClassImpl, common::lock::PyMutex, types::{Constructor, GetDescriptor, Initializer, Representable}, - AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; /// classmethod(function) -> method diff --git a/vm/src/builtins/code.rs b/vm/src/builtins/code.rs index bedd71f241..d9d8895b19 100644 --- a/vm/src/builtins/code.rs +++ b/vm/src/builtins/code.rs @@ -4,6 +4,7 @@ use super::{PyStrRef, PyTupleRef, PyType, PyTypeRef}; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine, builtins::PyStrInterned, bytecode::{self, AsBag, BorrowedConstant, CodeFlags, Constant, ConstantBag}, class::{PyClassImpl, StaticType}, @@ -12,7 +13,6 @@ use crate::{ function::{FuncArgs, OptionalArg}, source_code::OneIndexed, types::Representable, - 
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine, }; use malachite_bigint::BigInt; use num_traits::Zero; diff --git a/vm/src/builtins/complex.rs b/vm/src/builtins/complex.rs index fb5bc5066c..e665d1e27a 100644 --- a/vm/src/builtins/complex.rs +++ b/vm/src/builtins/complex.rs @@ -1,5 +1,6 @@ -use super::{float, PyStr, PyType, PyTypeRef}; +use super::{PyStr, PyType, PyTypeRef, float}; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, class::PyClassImpl, convert::{ToPyObject, ToPyResult}, function::{ @@ -11,7 +12,6 @@ use crate::{ protocol::PyNumberMethods, stdlib::warnings, types::{AsNumber, Comparable, Constructor, Hashable, PyComparisonOp, Representable}, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; use num_complex::Complex64; use num_traits::Zero; diff --git a/vm/src/builtins/coroutine.rs b/vm/src/builtins/coroutine.rs index db9592bd47..cca2db3293 100644 --- a/vm/src/builtins/coroutine.rs +++ b/vm/src/builtins/coroutine.rs @@ -1,12 +1,12 @@ use super::{PyCode, PyStrRef, PyType}; use crate::{ + AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, class::PyClassImpl, coroutine::Coro, frame::FrameRef, function::OptionalArg, protocol::PyIterReturn, types::{IterNext, Iterable, Representable, SelfIter, Unconstructible}, - AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; #[pyclass(module = false, name = "coroutine")] diff --git a/vm/src/builtins/descriptor.rs b/vm/src/builtins/descriptor.rs index 4b10f63f21..9da4e1d87a 100644 --- a/vm/src/builtins/descriptor.rs +++ b/vm/src/builtins/descriptor.rs @@ -1,10 +1,10 @@ use super::{PyStr, PyStrInterned, PyType}; use crate::{ - builtins::{builtin_func::PyNativeMethod, type_, PyTypeRef}, + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, + builtins::{PyTypeRef, builtin_func::PyNativeMethod, 
type_}, class::PyClassImpl, function::{FuncArgs, PyMethodDef, PyMethodFlags, PySetterValue}, types::{Callable, GetDescriptor, Representable, Unconstructible}, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; use rustpython_common::lock::PyRwLock; diff --git a/vm/src/builtins/dict.rs b/vm/src/builtins/dict.rs index 68567e493f..e54d2a1931 100644 --- a/vm/src/builtins/dict.rs +++ b/vm/src/builtins/dict.rs @@ -1,13 +1,14 @@ use super::{ - set::PySetInner, IterStatus, PositionIterInternal, PyBaseExceptionRef, PyGenericAlias, - PyMappingProxy, PySet, PyStr, PyStrRef, PyTupleRef, PyType, PyTypeRef, + IterStatus, PositionIterInternal, PyBaseExceptionRef, PyGenericAlias, PyMappingProxy, PySet, + PyStr, PyStrRef, PyTupleRef, PyType, PyTypeRef, set::PySetInner, }; use crate::{ - atomic_func, + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyRefExact, PyResult, + TryFromObject, atomic_func, builtins::{ + PyTuple, iter::{builtins_iter, builtins_reversed}, type_::PyAttributes, - PyTuple, }, class::{PyClassDef, PyClassImpl}, common::ascii, @@ -21,8 +22,6 @@ use crate::{ Initializer, IterNext, Iterable, PyComparisonOp, Representable, SelfIter, Unconstructible, }, vm::VirtualMachine, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyRefExact, PyResult, - TryFromObject, }; use once_cell::sync::Lazy; use rustpython_common::lock::PyMutex; diff --git a/vm/src/builtins/enumerate.rs b/vm/src/builtins/enumerate.rs index 64d7c1ed36..aa84115074 100644 --- a/vm/src/builtins/enumerate.rs +++ b/vm/src/builtins/enumerate.rs @@ -3,12 +3,12 @@ use super::{ }; use crate::common::lock::{PyMutex, PyRwLock}; use crate::{ + AsObject, Context, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, class::PyClassImpl, convert::ToPyObject, function::OptionalArg, protocol::{PyIter, PyIterReturn}, types::{Constructor, IterNext, Iterable, SelfIter}, - AsObject, Context, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, }; 
use malachite_bigint::BigInt; use num_traits::Zero; diff --git a/vm/src/builtins/filter.rs b/vm/src/builtins/filter.rs index 3b33ff766f..009a1b3eab 100644 --- a/vm/src/builtins/filter.rs +++ b/vm/src/builtins/filter.rs @@ -1,9 +1,9 @@ use super::{PyType, PyTypeRef}; use crate::{ + Context, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, class::PyClassImpl, protocol::{PyIter, PyIterReturn}, types::{Constructor, IterNext, Iterable, SelfIter}, - Context, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, }; #[pyclass(module = false, name = "filter", traverse)] diff --git a/vm/src/builtins/float.rs b/vm/src/builtins/float.rs index 1cd041b7b9..e33c25cb56 100644 --- a/vm/src/builtins/float.rs +++ b/vm/src/builtins/float.rs @@ -1,9 +1,11 @@ // spell-checker:ignore numer denom use super::{ - try_bigint_to_f64, PyByteArray, PyBytes, PyInt, PyIntRef, PyStr, PyStrRef, PyType, PyTypeRef, + PyByteArray, PyBytes, PyInt, PyIntRef, PyStr, PyStrRef, PyType, PyTypeRef, try_bigint_to_f64, }; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, + TryFromBorrowedObject, TryFromObject, VirtualMachine, class::PyClassImpl, common::{float_ops, hash}, convert::{IntoPyException, ToPyObject, ToPyResult}, @@ -14,8 +16,6 @@ use crate::{ }, protocol::PyNumberMethods, types::{AsNumber, Callable, Comparable, Constructor, Hashable, PyComparisonOp, Representable}, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, - TryFromBorrowedObject, TryFromObject, VirtualMachine, }; use malachite_bigint::{BigInt, ToBigInt}; use num_complex::Complex64; @@ -124,8 +124,8 @@ fn inner_divmod(v1: f64, v2: f64, vm: &VirtualMachine) -> PyResult<(f64, f64)> { pub fn float_pow(v1: f64, v2: f64, vm: &VirtualMachine) -> PyResult { if v1.is_zero() && v2.is_sign_negative() { - let msg = format!("{v1} cannot be raised to a negative power"); - Err(vm.new_zero_division_error(msg)) + let msg = "0.0 cannot be raised to a negative power"; + 
Err(vm.new_zero_division_error(msg.to_owned())) } else if v1.is_sign_negative() && (v2.floor() - v2).abs() > f64::EPSILON { let v1 = Complex64::new(v1, 0.); let v2 = Complex64::new(v2, 0.); diff --git a/vm/src/builtins/frame.rs b/vm/src/builtins/frame.rs index 3cc7d788fb..1fd031984a 100644 --- a/vm/src/builtins/frame.rs +++ b/vm/src/builtins/frame.rs @@ -4,11 +4,11 @@ use super::{PyCode, PyDictRef, PyIntRef, PyStrRef}; use crate::{ + AsObject, Context, Py, PyObjectRef, PyRef, PyResult, VirtualMachine, class::PyClassImpl, frame::{Frame, FrameRef}, function::PySetterValue, types::{Representable, Unconstructible}, - AsObject, Context, Py, PyObjectRef, PyRef, PyResult, VirtualMachine, }; use num_traits::Zero; diff --git a/vm/src/builtins/function.rs b/vm/src/builtins/function.rs index eb5a142f0c..63cf8ac5c3 100644 --- a/vm/src/builtins/function.rs +++ b/vm/src/builtins/function.rs @@ -2,8 +2,8 @@ mod jitfunc; use super::{ - tuple::PyTupleTyped, PyAsyncGen, PyCode, PyCoroutine, PyDictRef, PyGenerator, PyStr, PyStrRef, - PyTupleRef, PyType, PyTypeRef, + PyAsyncGen, PyCode, PyCoroutine, PyDictRef, PyGenerator, PyStr, PyStrRef, PyTupleRef, PyType, + PyTypeRef, tuple::PyTupleTyped, }; #[cfg(feature = "jit")] use crate::common::lock::OnceCell; @@ -12,6 +12,7 @@ use crate::convert::ToPyObject; use crate::function::ArgMapping; use crate::object::{Traverse, TraverseFn}; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, bytecode, class::PyClassImpl, frame::Frame, @@ -20,7 +21,6 @@ use crate::{ types::{ Callable, Comparable, Constructor, GetAttr, GetDescriptor, PyComparisonOp, Representable, }, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; use itertools::Itertools; #[cfg(feature = "jit")] diff --git a/vm/src/builtins/function/jitfunc.rs b/vm/src/builtins/function/jitfunc.rs index d46458fc65..a46d9aa0f3 100644 --- a/vm/src/builtins/function/jitfunc.rs +++ 
b/vm/src/builtins/function/jitfunc.rs @@ -1,9 +1,9 @@ use crate::{ - builtins::{bool_, float, int, PyBaseExceptionRef, PyDictRef, PyFunction, PyStrInterned}, + AsObject, Py, PyObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine, + builtins::{PyBaseExceptionRef, PyDictRef, PyFunction, PyStrInterned, bool_, float, int}, bytecode::CodeFlags, convert::ToPyObject, function::FuncArgs, - AsObject, Py, PyObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; use num_traits::ToPrimitive; use rustpython_jit::{AbiValue, Args, CompiledCode, JitArgumentError, JitType}; diff --git a/vm/src/builtins/generator.rs b/vm/src/builtins/generator.rs index e0ec77006d..db9c263cb2 100644 --- a/vm/src/builtins/generator.rs +++ b/vm/src/builtins/generator.rs @@ -4,13 +4,13 @@ use super::{PyCode, PyStrRef, PyType}; use crate::{ + AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, class::PyClassImpl, coroutine::Coro, frame::FrameRef, function::OptionalArg, protocol::PyIterReturn, types::{IterNext, Iterable, Representable, SelfIter, Unconstructible}, - AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; #[pyclass(module = false, name = "generator")] diff --git a/vm/src/builtins/genericalias.rs b/vm/src/builtins/genericalias.rs index 2746b03128..57c97ba62b 100644 --- a/vm/src/builtins/genericalias.rs +++ b/vm/src/builtins/genericalias.rs @@ -2,7 +2,8 @@ use once_cell::sync::Lazy; use super::type_; use crate::{ - atomic_func, + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, + VirtualMachine, atomic_func, builtins::{PyList, PyStr, PyTuple, PyTupleRef, PyType, PyTypeRef}, class::PyClassImpl, common::hash, @@ -13,8 +14,6 @@ use crate::{ AsMapping, AsNumber, Callable, Comparable, Constructor, GetAttr, Hashable, PyComparisonOp, Representable, }, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, - VirtualMachine, }; use std::fmt; diff --git 
a/vm/src/builtins/getset.rs b/vm/src/builtins/getset.rs index 603ec021e2..c2e11b770a 100644 --- a/vm/src/builtins/getset.rs +++ b/vm/src/builtins/getset.rs @@ -3,10 +3,10 @@ */ use super::{PyType, PyTypeRef}; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine, class::PyClassImpl, function::{IntoPyGetterFunc, IntoPySetterFunc, PyGetterFunc, PySetterFunc, PySetterValue}, types::{GetDescriptor, Unconstructible}, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine, }; #[pyclass(module = false, name = "getset_descriptor")] diff --git a/vm/src/builtins/int.rs b/vm/src/builtins/int.rs index eb05b4394a..db9c5ea4ed 100644 --- a/vm/src/builtins/int.rs +++ b/vm/src/builtins/int.rs @@ -1,5 +1,7 @@ -use super::{float, PyByteArray, PyBytes, PyStr, PyType, PyTypeRef}; +use super::{PyByteArray, PyBytes, PyStr, PyType, PyTypeRef, float}; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyRefExact, PyResult, + TryFromBorrowedObject, VirtualMachine, builtins::PyStrRef, bytesinner::PyBytesInner, class::PyClassImpl, @@ -14,8 +16,6 @@ use crate::{ }, protocol::PyNumberMethods, types::{AsNumber, Comparable, Constructor, Hashable, PyComparisonOp, Representable}, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyRefExact, PyResult, - TryFromBorrowedObject, VirtualMachine, }; use malachite_bigint::{BigInt, Sign}; use num_integer::Integer; @@ -109,11 +109,7 @@ fn inner_pow(int1: &BigInt, int2: &BigInt, vm: &VirtualMachine) -> PyResult { } else if int1.is_zero() { 0 } else if int1 == &BigInt::from(-1) { - if int2.is_odd() { - -1 - } else { - 1 - } + if int2.is_odd() { -1 } else { 1 } } else { // missing feature: BigInt exp // practically, exp over u64 is not possible to calculate anyway @@ -426,11 +422,7 @@ impl PyInt { // based on rust-num/num-integer#10, should hopefully be published soon fn normalize(a: BigInt, n: &BigInt) -> BigInt { let a = a % n; - if 
a.is_negative() { - a + n - } else { - a - } + if a.is_negative() { a + n } else { a } } fn inverse(a: BigInt, n: &BigInt) -> Option { use num_integer::*; @@ -642,7 +634,7 @@ impl PyInt { Sign::Minus if !signed => { return Err( vm.new_overflow_error("can't convert negative int to unsigned".to_owned()) - ) + ); } Sign::NoSign => return Ok(vec![0u8; byte_len].into()), _ => {} diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index 2fed23a5c4..5a47abfac7 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -4,12 +4,12 @@ use super::{PyInt, PyTupleRef, PyType}; use crate::{ + Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine, class::PyClassImpl, function::ArgCallable, object::{Traverse, TraverseFn}, protocol::{PyIterReturn, PySequence, PySequenceMethods}, types::{IterNext, Iterable, SelfIter}, - Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine, }; use rustpython_common::{ lock::{PyMutex, PyRwLock, PyRwLockUpgradableReadGuard}, @@ -28,7 +28,7 @@ pub enum IterStatus { unsafe impl Traverse for IterStatus { fn traverse(&self, tracer_fn: &mut TraverseFn) { match self { - IterStatus::Active(ref r) => r.traverse(tracer_fn), + IterStatus::Active(r) => r.traverse(tracer_fn), IterStatus::Exhausted => (), } } diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index 33d6072cd5..3b9624694e 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -4,6 +4,7 @@ use crate::common::lock::{ PyMappedRwLockReadGuard, PyMutex, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard, }; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, class::PyClassImpl, convert::ToPyObject, function::{ArgSize, FuncArgs, OptionalArg, PyComparisonValue}, @@ -18,7 +19,6 @@ use crate::{ }, utils::collection_repr, vm::VirtualMachine, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, }; use std::{fmt, ops::DerefMut}; diff --git a/vm/src/builtins/map.rs 
b/vm/src/builtins/map.rs index 44bacf587e..555e38c8b9 100644 --- a/vm/src/builtins/map.rs +++ b/vm/src/builtins/map.rs @@ -1,11 +1,11 @@ use super::{PyType, PyTypeRef}; use crate::{ + Context, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, builtins::PyTupleRef, class::PyClassImpl, function::PosArgs, protocol::{PyIter, PyIterReturn}, types::{Constructor, IterNext, Iterable, SelfIter}, - Context, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, }; #[pyclass(module = false, name = "map", traverse)] diff --git a/vm/src/builtins/mappingproxy.rs b/vm/src/builtins/mappingproxy.rs index 7b2386a39b..659562cec1 100644 --- a/vm/src/builtins/mappingproxy.rs +++ b/vm/src/builtins/mappingproxy.rs @@ -1,5 +1,6 @@ use super::{PyDict, PyDictRef, PyGenericAlias, PyList, PyTuple, PyType, PyTypeRef}; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, atomic_func, class::PyClassImpl, convert::ToPyObject, @@ -10,7 +11,6 @@ use crate::{ AsMapping, AsNumber, AsSequence, Comparable, Constructor, Iterable, PyComparisonOp, Representable, }, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; use once_cell::sync::Lazy; @@ -29,8 +29,8 @@ enum MappingProxyInner { unsafe impl Traverse for MappingProxyInner { fn traverse(&self, tracer_fn: &mut TraverseFn) { match self { - MappingProxyInner::Class(ref r) => r.traverse(tracer_fn), - MappingProxyInner::Mapping(ref arg) => arg.traverse(tracer_fn), + MappingProxyInner::Class(r) => r.traverse(tracer_fn), + MappingProxyInner::Mapping(arg) => arg.traverse(tracer_fn), } } } diff --git a/vm/src/builtins/memory.rs b/vm/src/builtins/memory.rs index dfbbdc5aaf..9426730f40 100644 --- a/vm/src/builtins/memory.rs +++ b/vm/src/builtins/memory.rs @@ -3,7 +3,8 @@ use super::{ PyTupleRef, PyType, PyTypeRef, }; use crate::{ - atomic_func, + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, + TryFromBorrowedObject, TryFromObject, 
VirtualMachine, atomic_func, buffer::FormatSpec, bytesinner::bytes_to_hex, class::PyClassImpl, @@ -24,8 +25,6 @@ use crate::{ AsBuffer, AsMapping, AsSequence, Comparable, Constructor, Hashable, IterNext, Iterable, PyComparisonOp, Representable, SelfIter, Unconstructible, }, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, - TryFromBorrowedObject, TryFromObject, VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; use itertools::Itertools; @@ -223,14 +222,16 @@ impl PyMemoryView { fn pos_from_multi_index(&self, indexes: &[isize], vm: &VirtualMachine) -> PyResult { match indexes.len().cmp(&self.desc.ndim()) { Ordering::Less => { - return Err(vm.new_not_implemented_error("sub-views are not implemented".to_owned())) + return Err( + vm.new_not_implemented_error("sub-views are not implemented".to_owned()) + ); } Ordering::Greater => { return Err(vm.new_type_error(format!( "cannot index {}-dimension view with {}-element tuple", self.desc.ndim(), indexes.len() - ))) + ))); } Ordering::Equal => (), } @@ -380,11 +381,7 @@ impl PyMemoryView { } }; ret = vm.bool_eq(&a_val, &b_val); - if let Ok(b) = ret { - !b - } else { - true - } + if let Ok(b) = ret { !b } else { true } }); ret } diff --git a/vm/src/builtins/module.rs b/vm/src/builtins/module.rs index de55b9deae..8c8f22cf58 100644 --- a/vm/src/builtins/module.rs +++ b/vm/src/builtins/module.rs @@ -1,11 +1,11 @@ use super::{PyDictRef, PyStr, PyStrRef, PyType, PyTypeRef}; use crate::{ - builtins::{pystr::AsPyStr, PyStrInterned}, + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, + builtins::{PyStrInterned, pystr::AsPyStr}, class::PyClassImpl, convert::ToPyObject, function::{FuncArgs, PyMethodDef}, types::{GetAttr, Initializer, Representable}, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; #[pyclass(module = false, name = "module")] @@ -183,11 +183,12 @@ impl Initializer for PyModule { type Args = 
ModuleInitArgs; fn init(zelf: PyRef, args: Self::Args, vm: &VirtualMachine) -> PyResult<()> { - debug_assert!(zelf - .class() - .slots - .flags - .has_feature(crate::types::PyTypeFlags::HAS_DICT)); + debug_assert!( + zelf.class() + .slots + .flags + .has_feature(crate::types::PyTypeFlags::HAS_DICT) + ); zelf.init_dict(vm.ctx.intern_str(args.name.as_str()), args.doc, vm); Ok(()) } diff --git a/vm/src/builtins/namespace.rs b/vm/src/builtins/namespace.rs index 441fd014f0..38146baa72 100644 --- a/vm/src/builtins/namespace.rs +++ b/vm/src/builtins/namespace.rs @@ -1,5 +1,6 @@ -use super::{tuple::IntoPyTuple, PyTupleRef, PyType}; +use super::{PyTupleRef, PyType, tuple::IntoPyTuple}; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::PyDict, class::PyClassImpl, function::{FuncArgs, PyComparisonValue}, @@ -7,7 +8,6 @@ use crate::{ types::{ Comparable, Constructor, DefaultConstructor, Initializer, PyComparisonOp, Representable, }, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; /// A simple attribute-based namespace. 
diff --git a/vm/src/builtins/object.rs b/vm/src/builtins/object.rs index f783ee017c..cce1422d56 100644 --- a/vm/src/builtins/object.rs +++ b/vm/src/builtins/object.rs @@ -2,11 +2,11 @@ use super::{PyDictRef, PyList, PyStr, PyStrRef, PyType, PyTypeRef}; use crate::common::hash::PyHash; use crate::types::PyTypeFlags; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine, class::PyClassImpl, convert::ToPyResult, function::{Either, FuncArgs, PyArithmeticValue, PyComparisonValue, PySetterValue}, types::{Constructor, PyComparisonOp}, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine, }; use itertools::Itertools; @@ -113,7 +113,7 @@ fn object_getstate_default(obj: &PyObject, required: bool, vm: &VirtualMachine) // )); // } - let state = if obj.dict().map_or(true, |d| d.is_empty()) { + let state = if obj.dict().is_none_or(|d| d.is_empty()) { vm.ctx.none() } else { // let state = object_get_dict(obj.clone(), obj.ctx()).unwrap(); diff --git a/vm/src/builtins/property.rs b/vm/src/builtins/property.rs index 61e1ff1692..5bfae5a081 100644 --- a/vm/src/builtins/property.rs +++ b/vm/src/builtins/property.rs @@ -5,10 +5,10 @@ use super::{PyStrRef, PyType, PyTypeRef}; use crate::common::lock::PyRwLock; use crate::function::{IntoFuncArgs, PosArgs}; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, class::PyClassImpl, function::{FuncArgs, PySetterValue}, types::{Constructor, GetDescriptor, Initializer}, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; #[pyclass(module = false, name = "property", traverse)] diff --git a/vm/src/builtins/range.rs b/vm/src/builtins/range.rs index 602ea2d37f..b542a5f191 100644 --- a/vm/src/builtins/range.rs +++ b/vm/src/builtins/range.rs @@ -1,8 +1,9 @@ use super::{ - builtins_iter, tuple::tuple_hash, PyInt, PyIntRef, PySlice, PyTupleRef, PyType, PyTypeRef, + PyInt, PyIntRef, 
PySlice, PyTupleRef, PyType, PyTypeRef, builtins_iter, tuple::tuple_hash, }; use crate::{ - atomic_func, + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, + VirtualMachine, atomic_func, class::PyClassImpl, common::hash::PyHash, function::{ArgIndex, FuncArgs, OptionalArg, PyComparisonValue}, @@ -11,8 +12,6 @@ use crate::{ AsMapping, AsSequence, Comparable, Hashable, IterNext, Iterable, PyComparisonOp, Representable, SelfIter, Unconstructible, }, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, - VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; use malachite_bigint::{BigInt, Sign}; diff --git a/vm/src/builtins/set.rs b/vm/src/builtins/set.rs index 1ab1fe21cf..62a66c89b4 100644 --- a/vm/src/builtins/set.rs +++ b/vm/src/builtins/set.rs @@ -2,10 +2,11 @@ * Builtin set type with a sequence of unique items. */ use super::{ - builtins_iter, IterStatus, PositionIterInternal, PyDict, PyDictRef, PyGenericAlias, PyTupleRef, - PyType, PyTypeRef, + IterStatus, PositionIterInternal, PyDict, PyDictRef, PyGenericAlias, PyTupleRef, PyType, + PyTypeRef, builtins_iter, }; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, atomic_func, class::PyClassImpl, common::{ascii, hash::PyHash, lock::PyMutex, rc::PyRc}, @@ -21,7 +22,6 @@ use crate::{ }, utils::collection_repr, vm::VirtualMachine, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, }; use once_cell::sync::Lazy; use rustpython_common::{ diff --git a/vm/src/builtins/singletons.rs b/vm/src/builtins/singletons.rs index 65b171a262..da0c718c46 100644 --- a/vm/src/builtins/singletons.rs +++ b/vm/src/builtins/singletons.rs @@ -1,10 +1,10 @@ use super::{PyStrRef, PyType, PyTypeRef}; use crate::{ + Context, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, class::PyClassImpl, convert::ToPyObject, protocol::PyNumberMethods, types::{AsNumber, Constructor, 
Representable}, - Context, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, }; #[pyclass(module = false, name = "NoneType")] diff --git a/vm/src/builtins/slice.rs b/vm/src/builtins/slice.rs index 5da3649115..4194360f4a 100644 --- a/vm/src/builtins/slice.rs +++ b/vm/src/builtins/slice.rs @@ -2,13 +2,13 @@ // spell-checker:ignore sliceobject use super::{PyStrRef, PyTupleRef, PyType, PyTypeRef}; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, class::PyClassImpl, common::hash::{PyHash, PyUHash}, convert::ToPyObject, function::{ArgIndex, FuncArgs, OptionalArg, PyComparisonValue}, sliceable::SaturatedSlice, types::{Comparable, Constructor, Hashable, PyComparisonOp, Representable}, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; use malachite_bigint::{BigInt, ToBigInt}; use num_traits::{One, Signed, Zero}; @@ -264,11 +264,7 @@ impl Comparable for PySlice { let eq = vm.identical_or_equal(zelf.start_ref(vm), other.start_ref(vm))? && vm.identical_or_equal(&zelf.stop, &other.stop)? 
&& vm.identical_or_equal(zelf.step_ref(vm), other.step_ref(vm))?; - if op == PyComparisonOp::Ne { - !eq - } else { - eq - } + if op == PyComparisonOp::Ne { !eq } else { eq } } PyComparisonOp::Gt | PyComparisonOp::Ge => None .or_else(|| { diff --git a/vm/src/builtins/staticmethod.rs b/vm/src/builtins/staticmethod.rs index 59a5b18b5d..8e2333da7f 100644 --- a/vm/src/builtins/staticmethod.rs +++ b/vm/src/builtins/staticmethod.rs @@ -1,10 +1,10 @@ use super::{PyStr, PyType, PyTypeRef}; use crate::{ + Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, class::PyClassImpl, common::lock::PyMutex, function::FuncArgs, types::{Callable, Constructor, GetDescriptor, Initializer, Representable}, - Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; #[pyclass(module = false, name = "staticmethod", traverse)] diff --git a/vm/src/builtins/str.rs b/vm/src/builtins/str.rs index c0a096b211..bf9a9f679c 100644 --- a/vm/src/builtins/str.rs +++ b/vm/src/builtins/str.rs @@ -1,10 +1,12 @@ use super::{ + PositionIterInternal, PyBytesRef, PyDict, PyTupleRef, PyType, PyTypeRef, int::{PyInt, PyIntRef}, iter::IterStatus::{self, Exhausted}, - PositionIterInternal, PyBytesRef, PyDict, PyTupleRef, PyType, PyTypeRef, }; use crate::{ - anystr::{self, adjust_indices, AnyStr, AnyStrContainer, AnyStrWrapper}, + AsObject, Context, Py, PyExact, PyObject, PyObjectRef, PyPayload, PyRef, PyRefExact, PyResult, + TryFromBorrowedObject, VirtualMachine, + anystr::{self, AnyStr, AnyStrContainer, AnyStrWrapper, adjust_indices}, atomic_func, class::PyClassImpl, common::str::{BorrowedStr, PyStrKind, PyStrKindData}, @@ -20,8 +22,6 @@ use crate::{ AsMapping, AsNumber, AsSequence, Comparable, Constructor, Hashable, IterNext, Iterable, PyComparisonOp, Representable, SelfIter, Unconstructible, }, - AsObject, Context, Py, PyExact, PyObject, PyObjectRef, PyPayload, PyRef, PyRefExact, PyResult, - TryFromBorrowedObject, VirtualMachine, }; use ascii::{AsciiStr, AsciiString}; use 
bstr::ByteSlice; @@ -313,7 +313,7 @@ impl PyStr { /// # Safety /// Given `bytes` must be ascii pub unsafe fn new_ascii_unchecked(bytes: Vec) -> Self { - Self::new_str_unchecked(bytes, PyStrKind::Ascii) + unsafe { Self::new_str_unchecked(bytes, PyStrKind::Ascii) } } pub fn new_ref(zelf: impl Into, ctx: &Context) -> PyRef { @@ -904,11 +904,7 @@ impl PyStr { '\n' => 1, '\r' => { let is_rn = enumerated.next_if(|(_, ch)| *ch == '\n').is_some(); - if is_rn { - 2 - } else { - 1 - } + if is_rn { 2 } else { 1 } } '\x0b' | '\x0c' | '\x1c' | '\x1d' | '\x1e' | '\u{0085}' | '\u{2028}' | '\u{2029}' => ch.len_utf8(), @@ -1125,33 +1121,7 @@ impl PyStr { #[pymethod] fn expandtabs(&self, args: anystr::ExpandTabsArgs) -> String { - let tab_stop = args.tabsize(); - let mut expanded_str = String::with_capacity(self.byte_len()); - let mut tab_size = tab_stop; - let mut col_count = 0usize; - for ch in self.as_str().chars() { - match ch { - '\t' => { - let num_spaces = tab_size - col_count; - col_count += num_spaces; - let expand = " ".repeat(num_spaces); - expanded_str.push_str(&expand); - } - '\r' | '\n' => { - expanded_str.push(ch); - col_count = 0; - tab_size = 0; - } - _ => { - expanded_str.push(ch); - col_count += 1; - } - } - if col_count >= tab_size { - tab_size += tab_stop; - } - } - expanded_str + rustpython_common::str::expandtabs(self.as_str(), args.tabsize()) } #[pymethod] @@ -1591,80 +1561,6 @@ impl AsRef for PyExact { } } -#[cfg(test)] -mod tests { - use super::*; - use crate::Interpreter; - - #[test] - fn str_title() { - let tests = vec![ - (" Hello ", " hello "), - ("Hello ", "hello "), - ("Hello ", "Hello "), - ("Format This As Title String", "fOrMaT thIs aS titLe String"), - ("Format,This-As*Title;String", "fOrMaT,thIs-aS*titLe;String"), - ("Getint", "getInt"), - ("Greek Ωppercases ...", "greek ωppercases ..."), - ("Greek ῼitlecases ...", "greek ῳitlecases ..."), - ]; - for (title, input) in tests { - assert_eq!(PyStr::from(input).title().as_str(), title); - } - } - - 
#[test] - fn str_istitle() { - let pos = vec![ - "A", - "A Titlecased Line", - "A\nTitlecased Line", - "A Titlecased, Line", - "Greek Ωppercases ...", - "Greek ῼitlecases ...", - ]; - - for s in pos { - assert!(PyStr::from(s).istitle()); - } - - let neg = vec![ - "", - "a", - "\n", - "Not a capitalized String", - "Not\ta Titlecase String", - "Not--a Titlecase String", - "NOT", - ]; - for s in neg { - assert!(!PyStr::from(s).istitle()); - } - } - - #[test] - fn str_maketrans_and_translate() { - Interpreter::without_stdlib(Default::default()).enter(|vm| { - let table = vm.ctx.new_dict(); - table - .set_item("a", vm.ctx.new_str("🎅").into(), vm) - .unwrap(); - table.set_item("b", vm.ctx.none(), vm).unwrap(); - table - .set_item("c", vm.ctx.new_str(ascii!("xda")).into(), vm) - .unwrap(); - let translated = - PyStr::maketrans(table.into(), OptionalArg::Missing, OptionalArg::Missing, vm) - .unwrap(); - let text = PyStr::from("abc"); - let translated = text.translate(translated, vm).unwrap(); - assert_eq!(translated, "🎅xda".to_owned()); - let translated = text.translate(vm.ctx.new_int(3).into(), vm); - assert_eq!("TypeError", &*translated.unwrap_err().class().name(),); - }) - } -} - impl AnyStrWrapper for PyStrRef { type Str = str; fn as_ref(&self) -> &str { @@ -1806,3 +1702,77 @@ impl AsRef for PyStrInterned { self.as_str() } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::Interpreter; + + #[test] + fn str_title() { + let tests = vec![ + (" Hello ", " hello "), + ("Hello ", "hello "), + ("Hello ", "Hello "), + ("Format This As Title String", "fOrMaT thIs aS titLe String"), + ("Format,This-As*Title;String", "fOrMaT,thIs-aS*titLe;String"), + ("Getint", "getInt"), + ("Greek Ωppercases ...", "greek ωppercases ..."), + ("Greek ῼitlecases ...", "greek ῳitlecases ..."), + ]; + for (title, input) in tests { + assert_eq!(PyStr::from(input).title().as_str(), title); + } + } + + #[test] + fn str_istitle() { + let pos = vec![ + "A", + "A Titlecased Line", + "A\nTitlecased 
Line", + "A Titlecased, Line", + "Greek Ωppercases ...", + "Greek ῼitlecases ...", + ]; + + for s in pos { + assert!(PyStr::from(s).istitle()); + } + + let neg = vec![ + "", + "a", + "\n", + "Not a capitalized String", + "Not\ta Titlecase String", + "Not--a Titlecase String", + "NOT", + ]; + for s in neg { + assert!(!PyStr::from(s).istitle()); + } + } + + #[test] + fn str_maketrans_and_translate() { + Interpreter::without_stdlib(Default::default()).enter(|vm| { + let table = vm.ctx.new_dict(); + table + .set_item("a", vm.ctx.new_str("🎅").into(), vm) + .unwrap(); + table.set_item("b", vm.ctx.none(), vm).unwrap(); + table + .set_item("c", vm.ctx.new_str(ascii!("xda")).into(), vm) + .unwrap(); + let translated = + PyStr::maketrans(table.into(), OptionalArg::Missing, OptionalArg::Missing, vm) + .unwrap(); + let text = PyStr::from("abc"); + let translated = text.translate(translated, vm).unwrap(); + assert_eq!(translated, "🎅xda".to_owned()); + let translated = text.translate(vm.ctx.new_int(3).into(), vm); + assert_eq!("TypeError", &*translated.unwrap_err().class().name(),); + }) + } +} diff --git a/vm/src/builtins/super.rs b/vm/src/builtins/super.rs index a0192cea5c..5f363ebea5 100644 --- a/vm/src/builtins/super.rs +++ b/vm/src/builtins/super.rs @@ -5,11 +5,11 @@ See also [CPython source code.](https://github.com/python/cpython/blob/50b48572d use super::{PyStr, PyType, PyTypeRef}; use crate::{ + AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, class::PyClassImpl, common::lock::PyRwLock, function::{FuncArgs, IntoFuncArgs, OptionalArg}, types::{Callable, Constructor, GetAttr, GetDescriptor, Initializer, Representable}, - AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; #[pyclass(module = false, name = "super", traverse)] diff --git a/vm/src/builtins/traceback.rs b/vm/src/builtins/traceback.rs index 6506e95363..b0abbd006a 100644 --- a/vm/src/builtins/traceback.rs +++ b/vm/src/builtins/traceback.rs @@ -2,7 
+2,7 @@ use rustpython_common::lock::PyMutex; use super::PyType; use crate::{ - class::PyClassImpl, frame::FrameRef, source_code::LineNumber, Context, Py, PyPayload, PyRef, + Context, Py, PyPayload, PyRef, class::PyClassImpl, frame::FrameRef, source_code::LineNumber, }; #[pyclass(module = false, name = "traceback", traverse)] diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index 9d1cc2f5ce..66cb2799a7 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -2,6 +2,7 @@ use super::{PositionIterInternal, PyGenericAlias, PyStrRef, PyType, PyTypeRef}; use crate::common::{hash::PyHash, lock::PyMutex}; use crate::object::{Traverse, TraverseFn}; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, atomic_func, class::PyClassImpl, convert::{ToPyObject, TransmuteFromObject}, @@ -17,7 +18,6 @@ use crate::{ }, utils::collection_repr, vm::VirtualMachine, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, }; use once_cell::sync::Lazy; use std::{fmt, marker::PhantomData}; diff --git a/vm/src/builtins/type.rs b/vm/src/builtins/type.rs index 15df5ff3c5..08a15575ed 100644 --- a/vm/src/builtins/type.rs +++ b/vm/src/builtins/type.rs @@ -1,16 +1,18 @@ use super::{ - mappingproxy::PyMappingProxy, object, union_, PyClassMethod, PyDictRef, PyList, PyStr, - PyStrInterned, PyStrRef, PyTuple, PyTupleRef, PyWeak, + PyClassMethod, PyDictRef, PyList, PyStr, PyStrInterned, PyStrRef, PyTuple, PyTupleRef, PyWeak, + mappingproxy::PyMappingProxy, object, union_, }; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, + VirtualMachine, builtins::{ + PyBaseExceptionRef, descriptor::{ MemberGetter, MemberKind, MemberSetter, PyDescriptorOwned, PyMemberDef, PyMemberDescriptor, }, function::PyCellRef, tuple::{IntoPyTuple, PyTupleTyped}, - PyBaseExceptionRef, }, class::{PyClassImpl, StaticType}, common::{ @@ -26,10 +28,8 @@ use 
crate::{ types::{ AsNumber, Callable, Constructor, GetAttr, PyTypeFlags, PyTypeSlots, Representable, SetAttr, }, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, - VirtualMachine, }; -use indexmap::{map::Entry, IndexMap}; +use indexmap::{IndexMap, map::Entry}; use itertools::Itertools; use std::{borrow::Borrow, collections::HashSet, fmt, ops::Deref, pin::Pin, ptr::NonNull}; @@ -69,7 +69,7 @@ pub struct PointerSlot(NonNull); impl PointerSlot { pub unsafe fn borrow_static(&self) -> &'static T { - self.0.as_ref() + unsafe { self.0.as_ref() } } } diff --git a/vm/src/builtins/union.rs b/vm/src/builtins/union.rs index 668d87bdce..f9dc8f3131 100644 --- a/vm/src/builtins/union.rs +++ b/vm/src/builtins/union.rs @@ -1,5 +1,6 @@ use super::{genericalias, type_}; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, atomic_func, builtins::{PyFrozenSet, PyStr, PyTuple, PyTupleRef, PyType}, class::PyClassImpl, @@ -8,7 +9,6 @@ use crate::{ function::PyComparisonValue, protocol::{PyMappingMethods, PyNumberMethods}, types::{AsMapping, AsNumber, Comparable, GetAttr, Hashable, PyComparisonOp, Representable}, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; use once_cell::sync::Lazy; use std::fmt; diff --git a/vm/src/builtins/weakproxy.rs b/vm/src/builtins/weakproxy.rs index f271e4cddb..d17bc75118 100644 --- a/vm/src/builtins/weakproxy.rs +++ b/vm/src/builtins/weakproxy.rs @@ -1,6 +1,6 @@ use super::{PyStr, PyStrRef, PyType, PyTypeRef, PyWeak}; use crate::{ - atomic_func, + Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, atomic_func, class::PyClassImpl, common::hash::PyHash, function::{OptionalArg, PyComparisonValue, PySetterValue}, @@ -10,7 +10,6 @@ use crate::{ AsMapping, AsSequence, Comparable, Constructor, GetAttr, Hashable, IterNext, Iterable, PyComparisonOp, Representable, SetAttr, }, - Context, Py, PyObject, 
PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; use once_cell::sync::Lazy; diff --git a/vm/src/builtins/weakref.rs b/vm/src/builtins/weakref.rs index 1d52225a26..9b2f248aa9 100644 --- a/vm/src/builtins/weakref.rs +++ b/vm/src/builtins/weakref.rs @@ -4,10 +4,10 @@ use crate::common::{ hash::{self, PyHash}, }; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine, class::PyClassImpl, function::OptionalArg, types::{Callable, Comparable, Constructor, Hashable, PyComparisonOp, Representable}, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine, }; pub use crate::object::PyWeak; diff --git a/vm/src/builtins/zip.rs b/vm/src/builtins/zip.rs index 56c88f14c6..abd82b3ccb 100644 --- a/vm/src/builtins/zip.rs +++ b/vm/src/builtins/zip.rs @@ -1,11 +1,11 @@ use super::{PyType, PyTypeRef}; use crate::{ + AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, builtins::PyTupleRef, class::PyClassImpl, function::{ArgIntoBool, OptionalArg, PosArgs}, protocol::{PyIter, PyIterReturn}, types::{Constructor, IterNext, Iterable, SelfIter}, - AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, }; use rustpython_common::atomic::{self, PyAtomic, Radium}; diff --git a/vm/src/bytesinner.rs b/vm/src/bytesinner.rs index 79d4d96262..3754b75ee1 100644 --- a/vm/src/bytesinner.rs +++ b/vm/src/bytesinner.rs @@ -1,8 +1,9 @@ use crate::{ + AsObject, PyObject, PyObjectRef, PyPayload, PyResult, TryFromBorrowedObject, VirtualMachine, anystr::{self, AnyStr, AnyStrContainer, AnyStrWrapper}, builtins::{ - pystr, PyBaseExceptionRef, PyByteArray, PyBytes, PyBytesRef, PyInt, PyIntRef, PyStr, - PyStrRef, PyTypeRef, + PyBaseExceptionRef, PyByteArray, PyBytes, PyBytesRef, PyInt, PyIntRef, PyStr, PyStrRef, + PyTypeRef, pystr, }, byte::bytes_from_object, cformat::cformat_bytes, @@ -13,7 +14,6 @@ use crate::{ protocol::PyBuffer, sequence::{SequenceExt, 
SequenceMutExt}, types::PyComparisonOp, - AsObject, PyObject, PyObjectRef, PyPayload, PyResult, TryFromBorrowedObject, VirtualMachine, }; use bstr::ByteSlice; use itertools::Itertools; diff --git a/vm/src/cformat.rs b/vm/src/cformat.rs index 79f4d68833..6e14034d0b 100644 --- a/vm/src/cformat.rs +++ b/vm/src/cformat.rs @@ -2,13 +2,13 @@ //! as per the [Python Docs](https://docs.python.org/3/library/stdtypes.html#printf-style-string-formatting). use crate::{ + AsObject, PyObjectRef, PyResult, TryFromBorrowedObject, TryFromObject, VirtualMachine, builtins::{ - try_f64_to_bigint, tuple, PyBaseExceptionRef, PyByteArray, PyBytes, PyFloat, PyInt, PyStr, + PyBaseExceptionRef, PyByteArray, PyBytes, PyFloat, PyInt, PyStr, try_f64_to_bigint, tuple, }, function::ArgIntoFloat, protocol::PyBuffer, stdlib::builtins, - AsObject, PyObjectRef, PyResult, TryFromBorrowedObject, TryFromObject, VirtualMachine, }; use itertools::Itertools; use num_traits::cast::ToPrimitive; diff --git a/vm/src/class.rs b/vm/src/class.rs index 3b7f1bffcb..bc38d6bd61 100644 --- a/vm/src/class.rs +++ b/vm/src/class.rs @@ -5,7 +5,7 @@ use crate::{ function::PyMethodDef, identifier, object::Py, - types::{hash_not_implemented, PyTypeFlags, PyTypeSlots}, + types::{PyTypeFlags, PyTypeSlots, hash_not_implemented}, vm::Context, }; use rustpython_common::static_cell; diff --git a/vm/src/codecs.rs b/vm/src/codecs.rs index ff7bc48915..e104097413 100644 --- a/vm/src/codecs.rs +++ b/vm/src/codecs.rs @@ -1,9 +1,9 @@ use crate::{ + AsObject, Context, PyObject, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine, builtins::{PyBaseExceptionRef, PyBytesRef, PyStr, PyStrRef, PyTuple, PyTupleRef}, common::{ascii, lock::PyRwLock}, convert::ToPyObject, function::PyMethodDef, - AsObject, Context, PyObject, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine, }; use std::{borrow::Cow, collections::HashMap, fmt::Write, ops::Range}; @@ -619,11 +619,16 @@ fn surrogatepass_errors(err: PyObjectRef, vm: 
&VirtualMachine) -> PyResult<(PyOb // Not supported, fail with original exception return Err(err.downcast().unwrap()); } + + debug_assert!(range.start <= 0.max(s.len() - 1)); + debug_assert!(range.end >= 1.min(s.len())); + debug_assert!(range.end <= s.len()); + let mut c: u32 = 0; // Try decoding a single surrogate character. If there are more, // let the codec call us again. let p = &s.as_bytes()[range.start..]; - if p.len() - range.start >= byte_length { + if p.len().overflowing_sub(range.start).0 >= byte_length { match standard_encoding { StandardEncoding::Utf8 => { if (p[0] as u32 & 0xf0) == 0xe0 @@ -637,10 +642,10 @@ fn surrogatepass_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyOb } } StandardEncoding::Utf16Le => { - c = (p[1] as u32) << 8 | p[0] as u32; + c = ((p[1] as u32) << 8) | p[0] as u32; } StandardEncoding::Utf16Be => { - c = (p[0] as u32) << 8 | p[1] as u32; + c = ((p[0] as u32) << 8) | p[1] as u32; } StandardEncoding::Utf32Le => { c = ((p[3] as u32) << 24) diff --git a/vm/src/convert/to_pyobject.rs b/vm/src/convert/to_pyobject.rs index c61296fe69..840c0c65fb 100644 --- a/vm/src/convert/to_pyobject.rs +++ b/vm/src/convert/to_pyobject.rs @@ -1,4 +1,4 @@ -use crate::{builtins::PyBaseExceptionRef, PyObjectRef, PyResult, VirtualMachine}; +use crate::{PyObjectRef, PyResult, VirtualMachine, builtins::PyBaseExceptionRef}; /// Implemented by any type that can be returned from a built-in Python function. 
/// diff --git a/vm/src/convert/try_from.rs b/vm/src/convert/try_from.rs index 8abe829803..941e1fef2a 100644 --- a/vm/src/convert/try_from.rs +++ b/vm/src/convert/try_from.rs @@ -1,7 +1,7 @@ use crate::{ + Py, VirtualMachine, builtins::PyFloat, object::{AsObject, PyObject, PyObjectRef, PyPayload, PyRef, PyResult}, - Py, VirtualMachine, }; use num_traits::ToPrimitive; diff --git a/vm/src/coroutine.rs b/vm/src/coroutine.rs index 6d6b743310..56eb520b2c 100644 --- a/vm/src/coroutine.rs +++ b/vm/src/coroutine.rs @@ -1,9 +1,9 @@ use crate::{ + AsObject, PyObject, PyObjectRef, PyResult, VirtualMachine, builtins::{PyBaseExceptionRef, PyStrRef}, common::lock::PyMutex, frame::{ExecutionResult, FrameRef}, protocol::PyIterReturn, - AsObject, PyObject, PyObjectRef, PyResult, VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; @@ -36,8 +36,8 @@ pub struct Coro { exception: PyMutex>, // exc_state } -fn gen_name(gen: &PyObject, vm: &VirtualMachine) -> &'static str { - let typ = gen.class(); +fn gen_name(jen: &PyObject, vm: &VirtualMachine) -> &'static str { + let typ = jen.class(); if typ.is(vm.ctx.types.coroutine_type) { "coroutine" } else if typ.is(vm.ctx.types.async_generator) { @@ -67,7 +67,7 @@ impl Coro { fn run_with_context( &self, - gen: &PyObject, + jen: &PyObject, vm: &VirtualMachine, func: F, ) -> PyResult @@ -75,7 +75,7 @@ impl Coro { F: FnOnce(FrameRef) -> PyResult, { if self.running.compare_exchange(false, true).is_err() { - return Err(vm.new_value_error(format!("{} already executing", gen_name(gen, vm)))); + return Err(vm.new_value_error(format!("{} already executing", gen_name(jen, vm)))); } vm.push_exception(self.exception.lock().take()); @@ -90,7 +90,7 @@ impl Coro { pub fn send( &self, - gen: &PyObject, + jen: &PyObject, value: PyObjectRef, vm: &VirtualMachine, ) -> PyResult { @@ -102,22 +102,22 @@ impl Coro { } else if !vm.is_none(&value) { return Err(vm.new_type_error(format!( "can't send non-None value to a just-started {}", - gen_name(gen, vm), + 
gen_name(jen, vm), ))); } else { None }; - let result = self.run_with_context(gen, vm, |f| f.resume(value, vm)); + let result = self.run_with_context(jen, vm, |f| f.resume(value, vm)); self.maybe_close(&result); match result { Ok(exec_res) => Ok(exec_res.into_iter_return(vm)), Err(e) => { if e.fast_isinstance(vm.ctx.exceptions.stop_iteration) { let err = - vm.new_runtime_error(format!("{} raised StopIteration", gen_name(gen, vm))); + vm.new_runtime_error(format!("{} raised StopIteration", gen_name(jen, vm))); err.set_cause(Some(e)); Err(err) - } else if gen.class().is(vm.ctx.types.async_generator) + } else if jen.class().is(vm.ctx.types.async_generator) && e.fast_isinstance(vm.ctx.exceptions.stop_async_iteration) { let err = vm @@ -132,7 +132,7 @@ impl Coro { } pub fn throw( &self, - gen: &PyObject, + jen: &PyObject, exc_type: PyObjectRef, exc_val: PyObjectRef, exc_tb: PyObjectRef, @@ -141,16 +141,16 @@ impl Coro { if self.closed.load() { return Err(vm.normalize_exception(exc_type, exc_val, exc_tb)?); } - let result = self.run_with_context(gen, vm, |f| f.gen_throw(vm, exc_type, exc_val, exc_tb)); + let result = self.run_with_context(jen, vm, |f| f.gen_throw(vm, exc_type, exc_val, exc_tb)); self.maybe_close(&result); Ok(result?.into_iter_return(vm)) } - pub fn close(&self, gen: &PyObject, vm: &VirtualMachine) -> PyResult<()> { + pub fn close(&self, jen: &PyObject, vm: &VirtualMachine) -> PyResult<()> { if self.closed.load() { return Ok(()); } - let result = self.run_with_context(gen, vm, |f| { + let result = self.run_with_context(jen, vm, |f| { f.gen_throw( vm, vm.ctx.exceptions.generator_exit.to_owned().into(), @@ -161,7 +161,7 @@ impl Coro { self.closed.store(true); match result { Ok(ExecutionResult::Yield(_)) => { - Err(vm.new_runtime_error(format!("{} ignored GeneratorExit", gen_name(gen, vm)))) + Err(vm.new_runtime_error(format!("{} ignored GeneratorExit", gen_name(jen, vm)))) } Err(e) if !is_gen_exit(&e, vm) => Err(e), _ => Ok(()), @@ -183,10 +183,10 @@ impl 
Coro { pub fn set_name(&self, name: PyStrRef) { *self.name.lock() = name; } - pub fn repr(&self, gen: &PyObject, id: usize, vm: &VirtualMachine) -> String { + pub fn repr(&self, jen: &PyObject, id: usize, vm: &VirtualMachine) -> String { format!( "<{} object {} at {:#x}>", - gen_name(gen, vm), + gen_name(jen, vm), self.name.lock(), id ) diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs index 4baeef0bfc..7c8fd23834 100644 --- a/vm/src/dictdatatype.rs +++ b/vm/src/dictdatatype.rs @@ -4,9 +4,9 @@ //! And: http://code.activestate.com/recipes/578375/ use crate::{ + AsObject, Py, PyExact, PyObject, PyObjectRef, PyRefExact, PyResult, VirtualMachine, builtins::{PyInt, PyStr, PyStrInterned, PyStrRef}, convert::ToPyObject, - AsObject, Py, PyExact, PyObject, PyObjectRef, PyRefExact, PyResult, VirtualMachine, }; use crate::{ common::{ @@ -915,7 +915,7 @@ fn str_exact<'a>(obj: &'a PyObject, vm: &VirtualMachine) -> Option<&'a PyStr> { #[cfg(test)] mod tests { use super::*; - use crate::{common::ascii, Interpreter}; + use crate::{Interpreter, common::ascii}; #[test] fn test_insert() { diff --git a/vm/src/eval.rs b/vm/src/eval.rs index 35f27dc9d6..4c48efc700 100644 --- a/vm/src/eval.rs +++ b/vm/src/eval.rs @@ -1,4 +1,4 @@ -use crate::{compiler, scope::Scope, PyResult, VirtualMachine}; +use crate::{PyResult, VirtualMachine, compiler, scope::Scope}; pub fn eval(vm: &VirtualMachine, source: &str, scope: Scope, source_path: &str) -> PyResult { match vm.compile(source, compiler::Mode::Eval, source_path.to_owned()) { diff --git a/vm/src/exceptions.rs b/vm/src/exceptions.rs index 0b61c90174..5882fbe2bc 100644 --- a/vm/src/exceptions.rs +++ b/vm/src/exceptions.rs @@ -2,8 +2,9 @@ use self::types::{PyBaseException, PyBaseExceptionRef}; use crate::common::lock::PyRwLock; use crate::object::{Traverse, TraverseFn}; use crate::{ + AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, builtins::{ - traceback::PyTracebackRef, PyNone, PyStr, 
PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, + PyNone, PyStr, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, traceback::PyTracebackRef, }, class::{PyClassImpl, StaticType}, convert::{ToPyException, ToPyObject}, @@ -12,7 +13,6 @@ use crate::{ stdlib::sys, suggestion::offer_suggestions, types::{Callable, Constructor, Initializer, Representable}, - AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; use itertools::Itertools; @@ -374,7 +374,9 @@ fn write_traceback_entry( writeln!( output, r##" File "{}", line {}, in {}"##, - filename, tb_entry.lineno, tb_entry.frame.code.obj_name + filename.trim_start_matches(r"\\?\"), + tb_entry.lineno, + tb_entry.frame.code.obj_name )?; print_source_line(output, filename, tb_entry.lineno.to_usize())?; @@ -1175,13 +1177,13 @@ pub(super) mod types { use crate::common::lock::PyRwLock; #[cfg_attr(target_arch = "wasm32", allow(unused_imports))] use crate::{ + AsObject, PyObjectRef, PyRef, PyResult, VirtualMachine, builtins::{ - traceback::PyTracebackRef, tuple::IntoPyTuple, PyInt, PyStrRef, PyTupleRef, PyTypeRef, + PyInt, PyStrRef, PyTupleRef, PyTypeRef, traceback::PyTracebackRef, tuple::IntoPyTuple, }, convert::ToPyResult, function::FuncArgs, types::{Constructor, Initializer}, - AsObject, PyObjectRef, PyRef, PyResult, VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; use itertools::Itertools; diff --git a/vm/src/format.rs b/vm/src/format.rs index 8109ea00f4..7e9bb54265 100644 --- a/vm/src/format.rs +++ b/vm/src/format.rs @@ -1,9 +1,9 @@ use crate::{ + PyObject, PyResult, VirtualMachine, builtins::PyBaseExceptionRef, convert::{IntoPyException, ToPyException}, function::FuncArgs, stdlib::builtins, - PyObject, PyResult, VirtualMachine, }; use rustpython_format::*; diff --git a/vm/src/frame.rs b/vm/src/frame.rs index 8fc7e171b3..7cbe25909b 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -1,11 +1,12 @@ use 
crate::common::{boxvec::BoxVec, lock::PyMutex}; use crate::{ + AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, builtins::{ + PyBaseExceptionRef, PyCode, PyCoroutine, PyDict, PyDictRef, PyGenerator, PyList, PySet, + PySlice, PyStr, PyStrInterned, PyStrRef, PyTraceback, PyType, asyncgenerator::PyAsyncGenWrappedValue, function::{PyCell, PyCellRef, PyFunction}, tuple::{PyTuple, PyTupleRef, PyTupleTyped}, - PyBaseExceptionRef, PyCode, PyCoroutine, PyDict, PyDictRef, PyGenerator, PyList, PySet, - PySlice, PyStr, PyStrInterned, PyStrRef, PyTraceback, PyType, }, bytecode, convert::{IntoObject, ToPyResult}, @@ -17,7 +18,6 @@ use crate::{ source_code::SourceLocation, stdlib::{builtins, typing::_typing}, vm::{Context, PyMethod}, - AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, }; use indexmap::IndexMap; use itertools::Itertools; @@ -277,15 +277,17 @@ impl Py { } pub fn next_external_frame(&self, vm: &VirtualMachine) -> Option { - self.f_back(vm).map(|mut back| loop { - back = if let Some(back) = back.to_owned().f_back(vm) { - back - } else { - break back; - }; + self.f_back(vm).map(|mut back| { + loop { + back = if let Some(back) = back.to_owned().f_back(vm) { + back + } else { + break back; + }; - if !back.is_internal_frame() { - break back; + if !back.is_internal_frame() { + break back; + } } }) } @@ -426,19 +428,19 @@ impl ExecutingFrame<'_> { exc_val: PyObjectRef, exc_tb: PyObjectRef, ) -> PyResult { - if let Some(gen) = self.yield_from_target() { + if let Some(jen) = self.yield_from_target() { // borrow checker shenanigans - we only need to use exc_type/val/tb if the following // variable is Some - let thrower = if let Some(coro) = self.builtin_coro(gen) { + let thrower = if let Some(coro) = self.builtin_coro(jen) { Some(Either::A(coro)) } else { - vm.get_attribute_opt(gen.to_owned(), "throw")? + vm.get_attribute_opt(jen.to_owned(), "throw")? 
.map(Either::B) }; if let Some(thrower) = thrower { let ret = match thrower { Either::A(coro) => coro - .throw(gen, exc_type, exc_val, exc_tb, vm) + .throw(jen, exc_type, exc_val, exc_tb, vm) .to_pyresult(vm), // FIXME: Either::B(meth) => meth.call((exc_type, exc_val, exc_tb), vm), }; @@ -638,7 +640,7 @@ impl ExecutingFrame<'_> { match res { Ok(()) => {} Err(e) if e.fast_isinstance(vm.ctx.exceptions.key_error) => { - return Err(name_error(name, vm)) + return Err(name_error(name, vm)); } Err(e) => return Err(e), } @@ -649,7 +651,7 @@ impl ExecutingFrame<'_> { match self.globals.del_item(name, vm) { Ok(()) => {} Err(e) if e.fast_isinstance(vm.ctx.exceptions.key_error) => { - return Err(name_error(name, vm)) + return Err(name_error(name, vm)); } Err(e) => return Err(e), } @@ -1202,6 +1204,23 @@ impl ExecutingFrame<'_> { self.push_value(type_alias.into_ref(&vm.ctx).into()); Ok(None) } + bytecode::Instruction::ParamSpec => { + let param_spec_name = self.pop_value(); + let param_spec: PyObjectRef = _typing::make_paramspec(param_spec_name.clone()) + .into_ref(&vm.ctx) + .into(); + self.push_value(param_spec); + Ok(None) + } + bytecode::Instruction::TypeVarTuple => { + let type_var_tuple_name = self.pop_value(); + let type_var_tuple: PyObjectRef = + _typing::make_typevartuple(type_var_tuple_name.clone()) + .into_ref(&vm.ctx) + .into(); + self.push_value(type_var_tuple); + Ok(None) + } } } @@ -1551,16 +1570,16 @@ impl ExecutingFrame<'_> { fn _send( &self, - gen: &PyObject, + jen: &PyObject, val: PyObjectRef, vm: &VirtualMachine, ) -> PyResult { - match self.builtin_coro(gen) { - Some(coro) => coro.send(gen, val, vm), + match self.builtin_coro(jen) { + Some(coro) => coro.send(jen, val, vm), // FIXME: turn return type to PyResult then ExecutionResult will be simplified - None if vm.is_none(&val) => PyIter::new(gen).next(vm), + None if vm.is_none(&val) => PyIter::new(jen).next(vm), None => { - let meth = gen.get_attr("send", vm)?; + let meth = jen.get_attr("send", vm)?; 
PyIterReturn::from_pyresult(meth.call((val,), vm), vm) } } diff --git a/vm/src/function/argument.rs b/vm/src/function/argument.rs index bc60cfa253..b7fd509ef1 100644 --- a/vm/src/function/argument.rs +++ b/vm/src/function/argument.rs @@ -1,8 +1,8 @@ use crate::{ + AsObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, builtins::{PyBaseExceptionRef, PyTupleRef, PyTypeRef}, convert::ToPyObject, object::{Traverse, TraverseFn}, - AsObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, }; use indexmap::IndexMap; use itertools::Itertools; @@ -497,7 +497,7 @@ where { fn traverse(&self, tracer_fn: &mut TraverseFn) { match self { - OptionalArg::Present(ref o) => o.traverse(tracer_fn), + OptionalArg::Present(o) => o.traverse(tracer_fn), OptionalArg::Missing => (), } } diff --git a/vm/src/function/arithmetic.rs b/vm/src/function/arithmetic.rs index b40e31e1b6..9f40ca7fec 100644 --- a/vm/src/function/arithmetic.rs +++ b/vm/src/function/arithmetic.rs @@ -1,7 +1,7 @@ use crate::{ + VirtualMachine, convert::{ToPyObject, TryFromObject}, object::{AsObject, PyObjectRef, PyResult}, - VirtualMachine, }; #[derive(result_like::OptionLike)] diff --git a/vm/src/function/buffer.rs b/vm/src/function/buffer.rs index f5d0dd03d6..91379e7a7f 100644 --- a/vm/src/function/buffer.rs +++ b/vm/src/function/buffer.rs @@ -1,9 +1,9 @@ use crate::{ + AsObject, PyObject, PyObjectRef, PyResult, TryFromBorrowedObject, TryFromObject, + VirtualMachine, builtins::{PyStr, PyStrRef}, common::borrow::{BorrowedValue, BorrowedValueMut}, protocol::PyBuffer, - AsObject, PyObject, PyObjectRef, PyResult, TryFromBorrowedObject, TryFromObject, - VirtualMachine, }; // Python/getargs.c diff --git a/vm/src/function/builtin.rs b/vm/src/function/builtin.rs index bf9a5ed345..b8a408453d 100644 --- a/vm/src/function/builtin.rs +++ b/vm/src/function/builtin.rs @@ -1,7 +1,7 @@ use super::{FromArgs, FuncArgs}; use crate::{ - convert::ToPyResult, object::PyThreadingConstraint, 
Py, PyPayload, PyRef, PyResult, - VirtualMachine, + Py, PyPayload, PyRef, PyResult, VirtualMachine, convert::ToPyResult, + object::PyThreadingConstraint, }; use std::marker::PhantomData; diff --git a/vm/src/function/either.rs b/vm/src/function/either.rs index ceb79d55c9..08b96c7fe3 100644 --- a/vm/src/function/either.rs +++ b/vm/src/function/either.rs @@ -1,5 +1,5 @@ use crate::{ - convert::ToPyObject, AsObject, PyObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine, + AsObject, PyObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine, convert::ToPyObject, }; use std::borrow::Borrow; diff --git a/vm/src/function/fspath.rs b/vm/src/function/fspath.rs index 41e99b0542..69f11eb65d 100644 --- a/vm/src/function/fspath.rs +++ b/vm/src/function/fspath.rs @@ -1,9 +1,9 @@ use crate::{ + PyObjectRef, PyResult, TryFromObject, VirtualMachine, builtins::{PyBytes, PyBytesRef, PyStrRef}, convert::{IntoPyException, ToPyObject}, function::PyStr, protocol::PyBuffer, - PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; use std::{ffi::OsStr, path::PathBuf}; diff --git a/vm/src/function/getset.rs b/vm/src/function/getset.rs index 827158e834..66e668ace6 100644 --- a/vm/src/function/getset.rs +++ b/vm/src/function/getset.rs @@ -2,10 +2,10 @@ */ use crate::{ + Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, convert::ToPyResult, function::{BorrowedParam, OwnedParam, RefParam}, object::PyThreadingConstraint, - Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, }; #[derive(result_like::OptionLike, is_macro::Is, Debug)] diff --git a/vm/src/function/method.rs b/vm/src/function/method.rs index c8652ffb50..d3d0b85fae 100644 --- a/vm/src/function/method.rs +++ b/vm/src/function/method.rs @@ -1,11 +1,11 @@ use crate::{ + Context, Py, PyObjectRef, PyPayload, PyRef, VirtualMachine, builtins::{ + PyType, builtin_func::{PyNativeFunction, PyNativeMethod}, descriptor::PyMethodDescriptor, - PyType, }, function::{IntoPyNativeFn, 
PyNativeFn}, - Context, Py, PyObjectRef, PyPayload, PyRef, VirtualMachine, }; bitflags::bitflags! { @@ -274,7 +274,7 @@ impl HeapMethodDef { impl Py { pub(crate) unsafe fn method(&self) -> &'static PyMethodDef { - &*(&self.method as *const _) + unsafe { &*(&self.method as *const _) } } pub fn build_function(&self, vm: &VirtualMachine) -> PyRef { diff --git a/vm/src/function/mod.rs b/vm/src/function/mod.rs index 3bd6d0f74c..8e517f6ed5 100644 --- a/vm/src/function/mod.rs +++ b/vm/src/function/mod.rs @@ -15,7 +15,7 @@ pub use argument::{ }; pub use arithmetic::{PyArithmeticValue, PyComparisonValue}; pub use buffer::{ArgAsciiBuffer, ArgBytesLike, ArgMemoryBuffer, ArgStrOrBytesLike}; -pub use builtin::{static_func, static_raw_func, IntoPyNativeFn, PyNativeFn}; +pub use builtin::{IntoPyNativeFn, PyNativeFn, static_func, static_raw_func}; pub use either::Either; pub use fspath::FsPath; pub use getset::PySetterValue; @@ -24,7 +24,7 @@ pub use method::{HeapMethodDef, PyMethodDef, PyMethodFlags}; pub use number::{ArgIndex, ArgIntoBool, ArgIntoComplex, ArgIntoFloat, ArgPrimitiveIndex, ArgSize}; pub use protocol::{ArgCallable, ArgIterable, ArgMapping, ArgSequence}; -use crate::{builtins::PyStr, convert::TryFromBorrowedObject, PyObject, PyResult, VirtualMachine}; +use crate::{PyObject, PyResult, VirtualMachine, builtins::PyStr, convert::TryFromBorrowedObject}; use builtin::{BorrowedParam, OwnedParam, RefParam}; #[derive(Clone, Copy, PartialEq, Eq)] diff --git a/vm/src/function/number.rs b/vm/src/function/number.rs index 5f23543395..0e36f57ad1 100644 --- a/vm/src/function/number.rs +++ b/vm/src/function/number.rs @@ -1,5 +1,5 @@ use super::argument::OptionalArg; -use crate::{builtins::PyIntRef, AsObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine}; +use crate::{AsObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine, builtins::PyIntRef}; use malachite_bigint::BigInt; use num_complex::Complex64; use num_traits::PrimInt; diff --git a/vm/src/function/protocol.rs 
b/vm/src/function/protocol.rs index 295332e480..4b7e4c4cef 100644 --- a/vm/src/function/protocol.rs +++ b/vm/src/function/protocol.rs @@ -1,12 +1,12 @@ use super::IntoFuncArgs; use crate::{ - builtins::{iter::PySequenceIterator, PyDict, PyDictRef}, + AsObject, PyObject, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine, + builtins::{PyDict, PyDictRef, iter::PySequenceIterator}, convert::ToPyObject, identifier, object::{Traverse, TraverseFn}, protocol::{PyIter, PyIterIter, PyMapping, PyMappingMethods}, types::{AsMapping, GenericMethod}, - AsObject, PyObject, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine, }; use std::{borrow::Borrow, marker::PhantomData, ops::Deref}; diff --git a/vm/src/import.rs b/vm/src/import.rs index f14ae31ef4..860f0b8a16 100644 --- a/vm/src/import.rs +++ b/vm/src/import.rs @@ -2,13 +2,12 @@ * Import mechanics */ use crate::{ - builtins::{list, traceback::PyTraceback, PyBaseExceptionRef, PyCode}, + AsObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, + builtins::{PyBaseExceptionRef, PyCode, list, traceback::PyTraceback}, scope::Scope, version::get_git_revision, - vm::{thread, VirtualMachine}, - AsObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, + vm::{VirtualMachine, thread}, }; -use rand::Rng; pub(crate) fn init_importlib_base(vm: &mut VirtualMachine) -> PyResult { flame_guard!("init importlib"); @@ -50,7 +49,7 @@ pub(crate) fn init_importlib_package(vm: &VirtualMachine, importlib: PyObjectRef let mut magic = get_git_revision().into_bytes(); magic.truncate(4); if magic.len() != 4 { - magic = rand::thread_rng().gen::<[u8; 4]>().to_vec(); + magic = rand::random::<[u8; 4]>().to_vec(); } let magic: PyObjectRef = vm.ctx.new_bytes(magic).into(); importlib_external.set_attr("MAGIC_NUMBER", magic, vm)?; diff --git a/vm/src/intern.rs b/vm/src/intern.rs index 45bd45d965..10aaa53454 100644 --- a/vm/src/intern.rs +++ b/vm/src/intern.rs @@ -1,8 +1,8 @@ use crate::{ + AsObject, Py, PyExact, 
PyObject, PyObjectRef, PyPayload, PyRef, PyRefExact, VirtualMachine, builtins::{PyStr, PyStrInterned, PyTypeRef}, common::lock::PyRwLock, convert::ToPyObject, - AsObject, Py, PyExact, PyObject, PyObjectRef, PyPayload, PyRef, PyRefExact, VirtualMachine, }; use std::{ borrow::{Borrow, ToOwned}, @@ -117,7 +117,7 @@ impl CachedPyStrRef { /// the given cache must be alive while returned reference is alive #[inline] unsafe fn as_interned_str(&self) -> &'static PyStrInterned { - std::mem::transmute_copy(self) + unsafe { std::mem::transmute_copy(self) } } #[inline] diff --git a/vm/src/iter.rs b/vm/src/iter.rs index 497dc20adc..1e13243792 100644 --- a/vm/src/iter.rs +++ b/vm/src/iter.rs @@ -1,4 +1,4 @@ -use crate::{types::PyComparisonOp, vm::VirtualMachine, PyObjectRef, PyResult}; +use crate::{PyObjectRef, PyResult, types::PyComparisonOp, vm::VirtualMachine}; use itertools::Itertools; pub trait PyExactSizeIterator<'a>: ExactSizeIterator + Sized { diff --git a/vm/src/object/core.rs b/vm/src/object/core.rs index b326935464..481111532a 100644 --- a/vm/src/object/core.rs +++ b/vm/src/object/core.rs @@ -11,9 +11,9 @@ //! PyRef may looking like to be called as PyObjectWeak by the rule, //! but not to do to remember it is a PyRef object. 
use super::{ + PyAtomicRef, ext::{AsObject, PyRefExact, PyResult}, payload::PyObjectPayload, - PyAtomicRef, }; use crate::object::traverse::{Traverse, TraverseFn}; use crate::object::traverse_object::PyObjVTable; @@ -77,19 +77,19 @@ use std::{ pub(super) struct Erased; pub(super) unsafe fn drop_dealloc_obj(x: *mut PyObject) { - drop(Box::from_raw(x as *mut PyInner)); + drop(unsafe { Box::from_raw(x as *mut PyInner) }); } pub(super) unsafe fn debug_obj( x: &PyObject, f: &mut fmt::Formatter, ) -> fmt::Result { - let x = &*(x as *const PyObject as *const PyInner); + let x = unsafe { &*(x as *const PyObject as *const PyInner) }; fmt::Debug::fmt(x, f) } /// Call `try_trace` on payload pub(super) unsafe fn try_trace_obj(x: &PyObject, tracer_fn: &mut TraverseFn) { - let x = &*(x as *const PyObject as *const PyInner); + let x = unsafe { &*(x as *const PyObject as *const PyInner) }; let payload = &x.payload; payload.try_traverse(tracer_fn) } @@ -278,7 +278,7 @@ impl WeakRefList { } unsafe fn dealloc(ptr: NonNull>) { - drop(Box::from_raw(ptr.as_ptr())); + drop(unsafe { Box::from_raw(ptr.as_ptr()) }); } fn get_weak_references(&self) -> Vec> { @@ -317,12 +317,14 @@ unsafe impl Link for WeakLink { #[inline(always)] unsafe fn from_raw(ptr: NonNull) -> Self::Handle { - PyRef::from_raw(ptr.as_ptr()) + // SAFETY: requirements forwarded from caller + unsafe { PyRef::from_raw(ptr.as_ptr()) } } #[inline(always)] unsafe fn pointers(target: NonNull) -> NonNull> { - NonNull::new_unchecked(&raw mut (*target.as_ptr()).0.payload.pointers) + // SAFETY: requirements forwarded from caller + unsafe { NonNull::new_unchecked(&raw mut (*target.as_ptr()).0.payload.pointers) } } } @@ -352,7 +354,7 @@ impl PyWeak { if !obj_ptr.as_ref().0.ref_count.safe_inc() { return None; } - Some(PyObjectRef::from_raw(obj_ptr.as_ptr())) + Some(PyObjectRef::from_raw(obj_ptr)) } } @@ -506,8 +508,8 @@ impl ToOwned for PyObject { impl PyObjectRef { #[inline(always)] - pub fn into_raw(self) -> *const PyObject { - let 
ptr = self.as_raw(); + pub fn into_raw(self) -> NonNull { + let ptr = self.ptr; std::mem::forget(self); ptr } @@ -518,10 +520,8 @@ impl PyObjectRef { /// dropped more than once due to mishandling the reference count by calling this function /// too many times. #[inline(always)] - pub unsafe fn from_raw(ptr: *const PyObject) -> Self { - Self { - ptr: NonNull::new_unchecked(ptr as *mut PyObject), - } + pub unsafe fn from_raw(ptr: NonNull) -> Self { + Self { ptr } } /// Attempt to downcast this reference to a subclass. @@ -567,7 +567,8 @@ impl PyObjectRef { #[inline(always)] pub unsafe fn downcast_unchecked_ref(&self) -> &Py { debug_assert!(self.payload_is::()); - &*(self as *const PyObjectRef as *const PyRef) + // SAFETY: requirements forwarded from caller + unsafe { &*(self as *const PyObjectRef as *const PyRef) } } // ideally we'd be able to define these in pyobject.rs, but method visibility rules are weird @@ -752,7 +753,8 @@ impl PyObject { #[inline(always)] pub unsafe fn downcast_unchecked_ref(&self) -> &Py { debug_assert!(self.payload_is::()); - &*(self as *const PyObject as *const Py) + // SAFETY: requirements forwarded from caller + unsafe { &*(self as *const PyObject as *const Py) } } #[inline(always)] @@ -814,13 +816,13 @@ impl PyObject { /// Can only be called when ref_count has dropped to zero. 
`ptr` must be valid #[inline(never)] unsafe fn drop_slow(ptr: NonNull) { - if let Err(()) = ptr.as_ref().drop_slow_inner() { + if let Err(()) = unsafe { ptr.as_ref().drop_slow_inner() } { // abort drop for whatever reason return; } - let drop_dealloc = ptr.as_ref().0.vtable.drop_dealloc; + let drop_dealloc = unsafe { ptr.as_ref().0.vtable.drop_dealloc }; // call drop only when there are no references in scope - stacked borrows stuff - drop_dealloc(ptr.as_ptr()) + unsafe { drop_dealloc(ptr.as_ptr()) } } /// # Safety @@ -1022,7 +1024,7 @@ impl PyRef { #[inline(always)] pub(crate) unsafe fn from_raw(raw: *const Py) -> Self { Self { - ptr: NonNull::new_unchecked(raw as *mut _), + ptr: unsafe { NonNull::new_unchecked(raw as *mut _) }, } } diff --git a/vm/src/object/ext.rs b/vm/src/object/ext.rs index f7247bc5e0..8c6d367583 100644 --- a/vm/src/object/ext.rs +++ b/vm/src/object/ext.rs @@ -7,12 +7,18 @@ use crate::common::{ lock::PyRwLockReadGuard, }; use crate::{ + VirtualMachine, builtins::{PyBaseExceptionRef, PyStrInterned, PyType}, convert::{IntoPyException, ToPyObject, ToPyResult, TryFromObject}, vm::Context, - VirtualMachine, }; -use std::{borrow::Borrow, fmt, marker::PhantomData, ops::Deref, ptr::null_mut}; +use std::{ + borrow::Borrow, + fmt, + marker::PhantomData, + ops::Deref, + ptr::{NonNull, null_mut}, +}; /* Python objects and references. 
@@ -60,7 +66,7 @@ impl PyExact { /// Given reference must be exact type of payload T #[inline(always)] pub unsafe fn ref_unchecked(r: &Py) -> &Self { - &*(r as *const _ as *const Self) + unsafe { &*(r as *const _ as *const Self) } } } @@ -294,7 +300,7 @@ impl PyAtomicRef { pub unsafe fn swap(&self, pyref: PyRef) -> PyRef { let py = PyRef::leak(pyref) as *const Py as *mut _; let old = Radium::swap(&self.inner, py, Ordering::AcqRel); - PyRef::from_raw(old.cast()) + unsafe { PyRef::from_raw(old.cast()) } } pub fn swap_to_temporary_refs(&self, pyref: PyRef, vm: &VirtualMachine) { @@ -352,7 +358,7 @@ impl From for PyAtomicRef { fn from(obj: PyObjectRef) -> Self { let obj = obj.into_raw(); Self { - inner: Radium::new(obj as *mut _), + inner: Radium::new(obj.cast().as_ptr()), _phantom: Default::default(), } } @@ -379,8 +385,8 @@ impl PyAtomicRef { #[must_use] pub unsafe fn swap(&self, obj: PyObjectRef) -> PyObjectRef { let obj = obj.into_raw(); - let old = Radium::swap(&self.inner, obj as *mut _, Ordering::AcqRel); - PyObjectRef::from_raw(old as _) + let old = Radium::swap(&self.inner, obj.cast().as_ptr(), Ordering::AcqRel); + unsafe { PyObjectRef::from_raw(NonNull::new_unchecked(old.cast())) } } pub fn swap_to_temporary_refs(&self, obj: PyObjectRef, vm: &VirtualMachine) { @@ -393,7 +399,9 @@ impl PyAtomicRef { impl From> for PyAtomicRef> { fn from(obj: Option) -> Self { - let val = obj.map(|x| x.into_raw() as *mut _).unwrap_or(null_mut()); + let val = obj + .map(|x| x.into_raw().as_ptr().cast()) + .unwrap_or(null_mut()); Self { inner: Radium::new(val), _phantom: Default::default(), @@ -420,11 +428,11 @@ impl PyAtomicRef> { /// until no more reference can be used via PyAtomicRef::deref() #[must_use] pub unsafe fn swap(&self, obj: Option) -> Option { - let val = obj.map(|x| x.into_raw() as *mut _).unwrap_or(null_mut()); + let val = obj + .map(|x| x.into_raw().as_ptr().cast()) + .unwrap_or(null_mut()); let old = Radium::swap(&self.inner, val, Ordering::AcqRel); - 
old.cast::() - .as_ref() - .map(|x| PyObjectRef::from_raw(x)) + unsafe { NonNull::new(old.cast::()).map(|x| PyObjectRef::from_raw(x)) } } pub fn swap_to_temporary_refs(&self, obj: Option, vm: &VirtualMachine) { diff --git a/vm/src/object/payload.rs b/vm/src/object/payload.rs index d5f3d6330b..6413d6ae06 100644 --- a/vm/src/object/payload.rs +++ b/vm/src/object/payload.rs @@ -1,9 +1,9 @@ use crate::object::{MaybeTraverse, Py, PyObjectRef, PyRef, PyResult}; use crate::{ + PyRefExact, builtins::{PyBaseExceptionRef, PyType, PyTypeRef}, types::PyTypeFlags, vm::{Context, VirtualMachine}, - PyRefExact, }; cfg_if::cfg_if! { diff --git a/vm/src/object/traverse.rs b/vm/src/object/traverse.rs index 43a039d331..5f93dc5c8b 100644 --- a/vm/src/object/traverse.rs +++ b/vm/src/object/traverse.rs @@ -2,7 +2,7 @@ use std::ptr::NonNull; use rustpython_common::lock::{PyMutex, PyRwLock}; -use crate::{function::Either, object::PyObjectPayload, AsObject, PyObject, PyObjectRef, PyRef}; +use crate::{AsObject, PyObject, PyObjectRef, PyRef, function::Either, object::PyObjectPayload}; pub type TraverseFn<'a> = dyn FnMut(&PyObject) + 'a; diff --git a/vm/src/object/traverse_object.rs b/vm/src/object/traverse_object.rs index b690b80f75..682ddf4876 100644 --- a/vm/src/object/traverse_object.rs +++ b/vm/src/object/traverse_object.rs @@ -1,10 +1,10 @@ use std::fmt; use crate::{ + PyObject, object::{ - debug_obj, drop_dealloc_obj, try_trace_obj, Erased, InstanceDict, PyInner, PyObjectPayload, + Erased, InstanceDict, PyInner, PyObjectPayload, debug_obj, drop_dealloc_obj, try_trace_obj, }, - PyObject, }; use super::{Traverse, TraverseFn}; diff --git a/vm/src/ospath.rs b/vm/src/ospath.rs index fe26f227d1..9dda60d621 100644 --- a/vm/src/ospath.rs +++ b/vm/src/ospath.rs @@ -1,9 +1,9 @@ use crate::{ + PyObjectRef, PyResult, VirtualMachine, builtins::PyBaseExceptionRef, convert::{ToPyException, TryFromObject}, function::FsPath, object::AsObject, - PyObjectRef, PyResult, VirtualMachine, }; use 
std::path::{Path, PathBuf}; diff --git a/vm/src/protocol/buffer.rs b/vm/src/protocol/buffer.rs index 3783ccf4b6..8692e4f78e 100644 --- a/vm/src/protocol/buffer.rs +++ b/vm/src/protocol/buffer.rs @@ -2,6 +2,7 @@ //! https://docs.python.org/3/c-api/buffer.html use crate::{ + Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromBorrowedObject, VirtualMachine, common::{ borrow::{BorrowedValue, BorrowedValueMut}, lock::{MapImmutable, PyMutex, PyMutexGuard}, @@ -9,7 +10,6 @@ use crate::{ object::PyObjectPayload, sliceable::SequenceIndexOp, types::Unconstructible, - Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromBorrowedObject, VirtualMachine, }; use itertools::Itertools; use std::{borrow::Cow, fmt::Debug, ops::Range}; @@ -133,8 +133,11 @@ impl PyBuffer { // after this function, the owner should use forget() // or wrap PyBuffer in the ManaullyDrop to prevent drop() pub(crate) unsafe fn drop_without_release(&mut self) { - std::ptr::drop_in_place(&mut self.obj); - std::ptr::drop_in_place(&mut self.desc); + // SAFETY: requirements forwarded from caller + unsafe { + std::ptr::drop_in_place(&mut self.obj); + std::ptr::drop_in_place(&mut self.desc); + } } } diff --git a/vm/src/protocol/iter.rs b/vm/src/protocol/iter.rs index d9c49ed15f..345914e411 100644 --- a/vm/src/protocol/iter.rs +++ b/vm/src/protocol/iter.rs @@ -1,8 +1,8 @@ use crate::{ + AsObject, PyObject, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine, builtins::iter::PySequenceIterator, convert::{ToPyObject, ToPyResult}, object::{Traverse, TraverseFn}, - AsObject, PyObject, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine, }; use std::borrow::Borrow; use std::ops::Deref; diff --git a/vm/src/protocol/mapping.rs b/vm/src/protocol/mapping.rs index 986b806f6c..cbecc8762e 100644 --- a/vm/src/protocol/mapping.rs +++ b/vm/src/protocol/mapping.rs @@ -1,12 +1,12 @@ use crate::{ + AsObject, PyObject, PyObjectRef, PyResult, VirtualMachine, builtins::{ + PyDict, 
PyStrInterned, dict::{PyDictItems, PyDictKeys, PyDictValues}, type_::PointerSlot, - PyDict, PyStrInterned, }, convert::ToPyResult, object::{Traverse, TraverseFn}, - AsObject, PyObject, PyObjectRef, PyResult, VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; diff --git a/vm/src/protocol/number.rs b/vm/src/protocol/number.rs index 01829bb553..65c4eaad79 100644 --- a/vm/src/protocol/number.rs +++ b/vm/src/protocol/number.rs @@ -3,13 +3,13 @@ use std::ops::Deref; use crossbeam_utils::atomic::AtomicCell; use crate::{ - builtins::{int, PyByteArray, PyBytes, PyComplex, PyFloat, PyInt, PyIntRef, PyStr}, + AsObject, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromBorrowedObject, + VirtualMachine, + builtins::{PyByteArray, PyBytes, PyComplex, PyFloat, PyInt, PyIntRef, PyStr, int}, common::int::bytes_to_int, function::ArgBytesLike, object::{Traverse, TraverseFn}, stdlib::warnings, - AsObject, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromBorrowedObject, - VirtualMachine, }; pub type PyNumberUnaryFunc = fn(PyNumber, &VirtualMachine) -> PyResult; diff --git a/vm/src/protocol/object.rs b/vm/src/protocol/object.rs index 8b9c0e446c..4e69cf38a2 100644 --- a/vm/src/protocol/object.rs +++ b/vm/src/protocol/object.rs @@ -2,9 +2,10 @@ //! 
https://docs.python.org/3/c-api/object.html use crate::{ + AsObject, Py, PyObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine, builtins::{ - pystr::AsPyStr, PyAsyncGen, PyBytes, PyDict, PyDictRef, PyGenericAlias, PyInt, PyList, - PyStr, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, + PyAsyncGen, PyBytes, PyDict, PyDictRef, PyGenericAlias, PyInt, PyList, PyStr, PyStrRef, + PyTuple, PyTupleRef, PyType, PyTypeRef, pystr::AsPyStr, }, bytesinner::ByteInnerNewOptions, common::{hash::PyHash, str::to_ascii}, @@ -14,7 +15,6 @@ use crate::{ object::PyPayload, protocol::{PyIter, PyMapping, PySequence}, types::{Constructor, PyComparisonOp}, - AsObject, Py, PyObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; // RustPython doesn't need these items diff --git a/vm/src/protocol/sequence.rs b/vm/src/protocol/sequence.rs index 62cf828e40..46ac0c8be7 100644 --- a/vm/src/protocol/sequence.rs +++ b/vm/src/protocol/sequence.rs @@ -1,10 +1,10 @@ use crate::{ - builtins::{type_::PointerSlot, PyList, PyListRef, PySlice, PyTuple, PyTupleRef}, + PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine, + builtins::{PyList, PyListRef, PySlice, PyTuple, PyTupleRef, type_::PointerSlot}, convert::ToPyObject, function::PyArithmeticValue, object::{Traverse, TraverseFn}, protocol::{PyMapping, PyNumberBinaryOp}, - PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; use itertools::Itertools; diff --git a/vm/src/py_io.rs b/vm/src/py_io.rs index 8091b75afc..0d7c319bc0 100644 --- a/vm/src/py_io.rs +++ b/vm/src/py_io.rs @@ -1,7 +1,7 @@ use crate::{ + PyObject, PyObjectRef, PyResult, VirtualMachine, builtins::{PyBaseExceptionRef, PyBytes, PyStr}, common::ascii, - PyObject, PyObjectRef, PyResult, VirtualMachine, }; use std::{fmt, io, ops}; diff --git a/vm/src/py_serde.rs b/vm/src/py_serde.rs index b4b12e430d..ea72879e42 100644 --- a/vm/src/py_serde.rs +++ b/vm/src/py_serde.rs @@ -3,7 +3,7 @@ use num_traits::sign::Signed; use 
serde::de::{DeserializeSeed, Visitor}; use serde::ser::{Serialize, SerializeMap, SerializeSeq}; -use crate::builtins::{bool_, dict::PyDictRef, float, int, list::PyList, tuple::PyTuple, PyStr}; +use crate::builtins::{PyStr, bool_, dict::PyDictRef, float, int, list::PyList, tuple::PyTuple}; use crate::{AsObject, PyObject, PyObjectRef, VirtualMachine}; #[inline] diff --git a/vm/src/scope.rs b/vm/src/scope.rs index 47a1e5e3ff..12b878f847 100644 --- a/vm/src/scope.rs +++ b/vm/src/scope.rs @@ -1,4 +1,4 @@ -use crate::{builtins::PyDictRef, function::ArgMapping, VirtualMachine}; +use crate::{VirtualMachine, builtins::PyDictRef, function::ArgMapping}; use std::fmt; #[derive(Clone)] diff --git a/vm/src/sequence.rs b/vm/src/sequence.rs index 3f2d1e94d2..fc6e216809 100644 --- a/vm/src/sequence.rs +++ b/vm/src/sequence.rs @@ -1,10 +1,10 @@ use crate::{ + AsObject, PyObject, PyObjectRef, PyResult, builtins::PyIntRef, function::OptionalArg, sliceable::SequenceIndexOp, types::PyComparisonOp, - vm::{VirtualMachine, MAX_MEMORY_SIZE}, - AsObject, PyObject, PyObjectRef, PyResult, + vm::{MAX_MEMORY_SIZE, VirtualMachine}, }; use optional::Optioned; use std::ops::{Deref, Range}; diff --git a/vm/src/signal.rs b/vm/src/signal.rs index 7489282de2..346664fa24 100644 --- a/vm/src/signal.rs +++ b/vm/src/signal.rs @@ -69,7 +69,7 @@ pub fn assert_in_range(signum: i32, vm: &VirtualMachine) -> PyResult<()> { #[allow(dead_code)] #[cfg(not(target_arch = "wasm32"))] pub fn set_interrupt_ex(signum: i32, vm: &VirtualMachine) -> PyResult<()> { - use crate::stdlib::signal::_signal::{run_signal, SIG_DFL, SIG_IGN}; + use crate::stdlib::signal::_signal::{SIG_DFL, SIG_IGN, run_signal}; assert_in_range(signum, vm)?; match signum as usize { diff --git a/vm/src/sliceable.rs b/vm/src/sliceable.rs index 257d325651..cbc25e4e18 100644 --- a/vm/src/sliceable.rs +++ b/vm/src/sliceable.rs @@ -1,7 +1,7 @@ // export through sliceable module, not slice. 
use crate::{ - builtins::{int::PyInt, slice::PySlice}, PyObject, PyResult, VirtualMachine, + builtins::{int::PyInt, slice::PySlice}, }; use malachite_bigint::BigInt; use num_traits::{Signed, ToPrimitive}; @@ -357,13 +357,8 @@ impl SaturatedSlice { if step == 0 { return Err(vm.new_value_error("slice step cannot be zero".to_owned())); } - let start = to_isize_index(vm, slice.start_ref(vm))?.unwrap_or_else(|| { - if step.is_negative() { - isize::MAX - } else { - 0 - } - }); + let start = to_isize_index(vm, slice.start_ref(vm))? + .unwrap_or_else(|| if step.is_negative() { isize::MAX } else { 0 }); let stop = to_isize_index(vm, &slice.stop(vm))?.unwrap_or_else(|| { if step.is_negative() { diff --git a/vm/src/stdlib/ast.rs b/vm/src/stdlib/ast.rs index 8b081294b9..7dd893646e 100644 --- a/vm/src/stdlib/ast.rs +++ b/vm/src/stdlib/ast.rs @@ -3,17 +3,17 @@ //! This module makes use of the parser logic, and translates all ast nodes //! into python ast.AST objects. -mod gen; +mod r#gen; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, + VirtualMachine, builtins::{self, PyDict, PyModule, PyStrRef, PyType}, class::{PyClassImpl, StaticType}, - compiler::core::bytecode::OpArgType, compiler::CompileError, + compiler::core::bytecode::OpArgType, convert::ToPyException, source_code::{LinearLocator, OneIndexed, SourceLocation, SourceRange}, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, - VirtualMachine, }; use num_complex::Complex64; use num_traits::{ToPrimitive, Zero}; @@ -26,9 +26,9 @@ use rustpython_parser as parser; #[pymodule] mod _ast { use crate::{ + AsObject, Context, PyObjectRef, PyPayload, PyResult, VirtualMachine, builtins::{PyStrRef, PyTupleRef}, function::FuncArgs, - AsObject, Context, PyObjectRef, PyPayload, PyResult, VirtualMachine, }; #[pyattr] #[pyclass(module = "_ast", name = "AST")] @@ -398,6 +398,6 @@ pub const PY_COMPILE_FLAGS_MASK: i32 = PY_COMPILE_FLAG_AST_ONLY pub 
fn make_module(vm: &VirtualMachine) -> PyRef { let module = _ast::make_module(vm); - gen::extend_module_nodes(vm, &module); + r#gen::extend_module_nodes(vm, &module); module } diff --git a/vm/src/stdlib/atexit.rs b/vm/src/stdlib/atexit.rs index dbeda76741..b1832b5481 100644 --- a/vm/src/stdlib/atexit.rs +++ b/vm/src/stdlib/atexit.rs @@ -3,7 +3,7 @@ pub(crate) use atexit::make_module; #[pymodule] mod atexit { - use crate::{function::FuncArgs, AsObject, PyObjectRef, PyResult, VirtualMachine}; + use crate::{AsObject, PyObjectRef, PyResult, VirtualMachine, function::FuncArgs}; #[pyfunction] fn register(func: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyObjectRef { diff --git a/vm/src/stdlib/builtins.rs b/vm/src/stdlib/builtins.rs index cc9b0a8be7..22aa11dc43 100644 --- a/vm/src/stdlib/builtins.rs +++ b/vm/src/stdlib/builtins.rs @@ -1,7 +1,7 @@ //! Builtin function definitions. //! //! Implements the list of [builtin Python functions](https://docs.python.org/3/library/builtins.html). 
-use crate::{builtins::PyModule, class::PyClassImpl, Py, VirtualMachine}; +use crate::{Py, VirtualMachine, builtins::PyModule, class::PyClassImpl}; pub(crate) use builtins::{__module_def, DOC}; pub use builtins::{ascii, print, reversed}; @@ -10,13 +10,14 @@ mod builtins { use std::io::IsTerminal; use crate::{ + AsObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, builtins::{ + PyByteArray, PyBytes, PyDictRef, PyStr, PyStrRef, PyTuple, PyTupleRef, PyType, enumerate::PyReverseSequenceIterator, function::{PyCellRef, PyFunction}, int::PyIntRef, iter::PyCallableIterator, list::{PyList, SortOptions}, - PyByteArray, PyBytes, PyDictRef, PyStr, PyStrRef, PyTuple, PyTupleRef, PyType, }, common::{hash::PyHash, str::to_ascii}, function::{ @@ -29,7 +30,6 @@ mod builtins { readline::{Readline, ReadlineResult}, stdlib::sys, types::PyComparisonOp, - AsObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, }; use num_traits::{Signed, ToPrimitive}; @@ -157,8 +157,7 @@ mod builtins { #[cfg(not(feature = "rustpython-parser"))] { - const PARSER_NOT_SUPPORTED: &str = - "can't compile() source code when the `parser` feature of rustpython is disabled"; + const PARSER_NOT_SUPPORTED: &str = "can't compile() source code when the `parser` feature of rustpython is disabled"; Err(vm.new_type_error(PARSER_NOT_SUPPORTED.to_owned())) } #[cfg(feature = "rustpython-parser")] @@ -535,7 +534,7 @@ mod builtins { None => { return default.ok_or_else(|| { vm.new_value_error(format!("{func_name}() arg is an empty sequence")) - }) + }); } }; diff --git a/vm/src/stdlib/codecs.rs b/vm/src/stdlib/codecs.rs index 71317bf484..976545f64b 100644 --- a/vm/src/stdlib/codecs.rs +++ b/vm/src/stdlib/codecs.rs @@ -4,10 +4,10 @@ pub(crate) use _codecs::make_module; mod _codecs { use crate::common::encodings; use crate::{ + AsObject, PyObject, PyObjectRef, PyResult, TryFromBorrowedObject, VirtualMachine, builtins::{PyBaseExceptionRef, PyBytes, PyBytesRef, PyStr, 
PyStrRef, PyTuple}, codecs, function::{ArgBytesLike, FuncArgs}, - AsObject, PyObject, PyObjectRef, PyResult, TryFromBorrowedObject, VirtualMachine, }; use std::ops::Range; diff --git a/vm/src/stdlib/collections.rs b/vm/src/stdlib/collections.rs index 5a9a172d53..fc867db2b1 100644 --- a/vm/src/stdlib/collections.rs +++ b/vm/src/stdlib/collections.rs @@ -3,6 +3,7 @@ pub(crate) use _collections::make_module; #[pymodule] mod _collections { use crate::{ + AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, atomic_func, builtins::{ IterStatus::{Active, Exhausted}, @@ -20,7 +21,6 @@ mod _collections { Iterable, PyComparisonOp, Representable, SelfIter, }, utils::collection_repr, - AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; use std::cmp::max; diff --git a/vm/src/stdlib/ctypes.rs b/vm/src/stdlib/ctypes.rs new file mode 100644 index 0000000000..2580939b62 --- /dev/null +++ b/vm/src/stdlib/ctypes.rs @@ -0,0 +1,226 @@ +pub(crate) mod array; +pub(crate) mod base; +pub(crate) mod function; +pub(crate) mod library; +pub(crate) mod pointer; +pub(crate) mod structure; +pub(crate) mod union; + +use crate::builtins::PyModule; +use crate::class::PyClassImpl; +use crate::stdlib::ctypes::base::{PyCData, PyCSimple, PySimpleMeta}; +use crate::{Py, PyRef, VirtualMachine}; + +pub fn extend_module_nodes(vm: &VirtualMachine, module: &Py) { + let ctx = &vm.ctx; + PySimpleMeta::make_class(ctx); + extend_module!(vm, module, { + "_CData" => PyCData::make_class(ctx), + "_SimpleCData" => PyCSimple::make_class(ctx), + "Array" => array::PyCArray::make_class(ctx), + "CFuncPtr" => function::PyCFuncPtr::make_class(ctx), + "_Pointer" => pointer::PyCPointer::make_class(ctx), + "_pointer_type_cache" => ctx.new_dict(), + "Structure" => structure::PyCStructure::make_class(ctx), + "Union" => union::PyCUnion::make_class(ctx), + }) +} + +pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { + let 
module = _ctypes::make_module(vm); + extend_module_nodes(vm, &module); + module +} + +#[pymodule] +pub(crate) mod _ctypes { + use super::base::PyCSimple; + use crate::builtins::PyTypeRef; + use crate::class::StaticType; + use crate::function::Either; + use crate::stdlib::ctypes::library; + use crate::{AsObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine}; + use crossbeam_utils::atomic::AtomicCell; + use std::ffi::{ + c_double, c_float, c_int, c_long, c_longlong, c_schar, c_short, c_uchar, c_uint, c_ulong, + c_ulonglong, + }; + use std::mem; + use widestring::WideChar; + + #[pyattr(name = "__version__")] + const __VERSION__: &str = "1.1.0"; + + // TODO: get properly + #[pyattr(name = "RTLD_LOCAL")] + const RTLD_LOCAL: i32 = 0; + + // TODO: get properly + #[pyattr(name = "RTLD_GLOBAL")] + const RTLD_GLOBAL: i32 = 0; + + #[cfg(target_os = "windows")] + #[pyattr(name = "SIZEOF_TIME_T")] + pub const SIZEOF_TIME_T: usize = 8; + #[cfg(not(target_os = "windows"))] + #[pyattr(name = "SIZEOF_TIME_T")] + pub const SIZEOF_TIME_T: usize = 4; + + #[pyattr(name = "CTYPES_MAX_ARGCOUNT")] + pub const CTYPES_MAX_ARGCOUNT: usize = 1024; + + #[pyattr] + pub const FUNCFLAG_STDCALL: u32 = 0x0; + #[pyattr] + pub const FUNCFLAG_CDECL: u32 = 0x1; + #[pyattr] + pub const FUNCFLAG_HRESULT: u32 = 0x2; + #[pyattr] + pub const FUNCFLAG_PYTHONAPI: u32 = 0x4; + #[pyattr] + pub const FUNCFLAG_USE_ERRNO: u32 = 0x8; + #[pyattr] + pub const FUNCFLAG_USE_LASTERROR: u32 = 0x10; + + #[pyattr] + pub const TYPEFLAG_ISPOINTER: u32 = 0x100; + #[pyattr] + pub const TYPEFLAG_HASPOINTER: u32 = 0x200; + + #[pyattr] + pub const DICTFLAG_FINAL: u32 = 0x1000; + + #[pyattr(name = "ArgumentError", once)] + fn argument_error(vm: &VirtualMachine) -> PyTypeRef { + vm.ctx.new_exception_type( + "_ctypes", + "ArgumentError", + Some(vec![vm.ctx.exceptions.exception_type.to_owned()]), + ) + } + + #[pyattr(name = "FormatError", once)] + fn format_error(vm: &VirtualMachine) -> PyTypeRef { + 
vm.ctx.new_exception_type( + "_ctypes", + "FormatError", + Some(vec![vm.ctx.exceptions.exception_type.to_owned()]), + ) + } + + pub fn get_size(ty: &str) -> usize { + match ty { + "u" => mem::size_of::(), + "c" | "b" => mem::size_of::(), + "h" => mem::size_of::(), + "H" => mem::size_of::(), + "i" => mem::size_of::(), + "I" => mem::size_of::(), + "l" => mem::size_of::(), + "q" => mem::size_of::(), + "L" => mem::size_of::(), + "Q" => mem::size_of::(), + "f" => mem::size_of::(), + "d" | "g" => mem::size_of::(), + "?" | "B" => mem::size_of::(), + "P" | "z" | "Z" => mem::size_of::(), + _ => unreachable!(), + } + } + + const SIMPLE_TYPE_CHARS: &str = "cbBhHiIlLdfguzZPqQ?"; + + pub fn new_simple_type( + cls: Either<&PyObjectRef, &PyTypeRef>, + vm: &VirtualMachine, + ) -> PyResult { + let cls = match cls { + Either::A(obj) => obj, + Either::B(typ) => typ.as_object(), + }; + + if let Ok(_type_) = cls.get_attr("_type_", vm) { + if _type_.is_instance((&vm.ctx.types.str_type).as_ref(), vm)? { + let tp_str = _type_.str(vm)?.to_string(); + + if tp_str.len() != 1 { + Err(vm.new_value_error( + format!("class must define a '_type_' attribute which must be a string of length 1, str: {tp_str}"), + )) + } else if !SIMPLE_TYPE_CHARS.contains(tp_str.as_str()) { + Err(vm.new_attribute_error(format!("class must define a '_type_' attribute which must be\n a single character string containing one of {SIMPLE_TYPE_CHARS}, currently it is {tp_str}."))) + } else { + Ok(PyCSimple { + _type_: tp_str, + value: AtomicCell::new(vm.ctx.none()), + }) + } + } else { + Err(vm.new_type_error("class must define a '_type_' string attribute".to_string())) + } + } else { + Err(vm.new_attribute_error("class must define a '_type_' attribute".to_string())) + } + } + + #[pyfunction(name = "sizeof")] + pub fn size_of(tp: Either, vm: &VirtualMachine) -> PyResult { + match tp { + Either::A(type_) if type_.fast_issubclass(PyCSimple::static_type()) => { + let zelf = new_simple_type(Either::B(&type_), vm)?; + 
Ok(get_size(zelf._type_.as_str())) + } + Either::B(obj) if obj.has_attr("size_of_instances", vm)? => { + let size_of_method = obj.get_attr("size_of_instances", vm)?; + let size_of_return = size_of_method.call(vec![], vm)?; + Ok(usize::try_from_object(vm, size_of_return)?) + } + _ => Err(vm.new_type_error("this type has no size".to_string())), + } + } + + #[pyfunction(name = "LoadLibrary")] + fn load_library(name: String, vm: &VirtualMachine) -> PyResult { + // TODO: audit functions first + let cache = library::libcache(); + let mut cache_write = cache.write(); + let lib_ref = cache_write.get_or_insert_lib(&name, vm).unwrap(); + Ok(lib_ref.get_pointer()) + } + + #[pyfunction(name = "FreeLibrary")] + fn free_library(handle: usize) -> PyResult<()> { + let cache = library::libcache(); + let mut cache_write = cache.write(); + cache_write.drop_lib(handle); + Ok(()) + } + + #[pyfunction(name = "POINTER")] + pub fn pointer(_cls: PyTypeRef) {} + + #[pyfunction] + pub fn pointer_fn(_inst: PyObjectRef) {} + + #[cfg(target_os = "windows")] + #[pyfunction(name = "_check_HRESULT")] + pub fn check_hresult(_self: PyObjectRef, hr: i32, _vm: &VirtualMachine) -> PyResult { + // TODO: fixme + if hr < 0 { + // vm.ctx.new_windows_error(hr) + todo!(); + } else { + Ok(hr) + } + } + + #[pyfunction] + fn get_errno() -> i32 { + errno::errno().0 + } + + #[pyfunction] + fn set_errno(value: i32) { + errno::set_errno(errno::Errno(value)); + } +} diff --git a/vm/src/stdlib/ctypes/array.rs b/vm/src/stdlib/ctypes/array.rs new file mode 100644 index 0000000000..8b023582c9 --- /dev/null +++ b/vm/src/stdlib/ctypes/array.rs @@ -0,0 +1,5 @@ +#[pyclass(name = "Array", module = "_ctypes")] +pub struct PyCArray {} + +#[pyclass(flags(BASETYPE, IMMUTABLETYPE))] +impl PyCArray {} diff --git a/vm/src/stdlib/ctypes/base.rs b/vm/src/stdlib/ctypes/base.rs new file mode 100644 index 0000000000..a4147c62b2 --- /dev/null +++ b/vm/src/stdlib/ctypes/base.rs @@ -0,0 +1,224 @@ +use crate::builtins::PyType; +use 
crate::builtins::{PyBytes, PyFloat, PyInt, PyNone, PyStr, PyTypeRef}; +use crate::convert::ToPyObject; +use crate::function::{Either, OptionalArg}; +use crate::stdlib::ctypes::_ctypes::new_simple_type; +use crate::types::Constructor; +use crate::{AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine}; +use crossbeam_utils::atomic::AtomicCell; +use num_traits::ToPrimitive; +use rustpython_common::lock::PyRwLock; +use std::fmt::Debug; + +#[allow(dead_code)] +fn set_primitive(_type_: &str, value: &PyObjectRef, vm: &VirtualMachine) -> PyResult { + match _type_ { + "c" => { + if value + .clone() + .downcast_exact::(vm) + .is_ok_and(|v| v.len() == 1) + || value + .clone() + .downcast_exact::(vm) + .is_ok_and(|v| v.len() == 1) + || value + .clone() + .downcast_exact::(vm) + .map_or(Ok(false), |v| { + let n = v.as_bigint().to_i64(); + if let Some(n) = n { + Ok((0..=255).contains(&n)) + } else { + Ok(false) + } + })? + { + Ok(value.clone()) + } else { + Err(vm.new_type_error( + "one character bytes, bytearray or integer expected".to_string(), + )) + } + } + "u" => { + if let Ok(b) = value.str(vm).map(|v| v.to_string().chars().count() == 1) { + if b { + Ok(value.clone()) + } else { + Err(vm.new_type_error("one character unicode string expected".to_string())) + } + } else { + Err(vm.new_type_error(format!( + "unicode string expected instead of {} instance", + value.class().name() + ))) + } + } + "b" | "h" | "H" | "i" | "I" | "l" | "q" | "L" | "Q" => { + if value.clone().downcast_exact::(vm).is_ok() { + Ok(value.clone()) + } else { + Err(vm.new_type_error(format!( + "an integer is required (got type {})", + value.class().name() + ))) + } + } + "f" | "d" | "g" => { + if value.clone().downcast_exact::(vm).is_ok() { + Ok(value.clone()) + } else { + Err(vm.new_type_error(format!("must be real number, not {}", value.class().name()))) + } + } + "?" 
=> Ok(PyObjectRef::from( + vm.ctx.new_bool(value.clone().try_to_bool(vm)?), + )), + "B" => { + if value.clone().downcast_exact::(vm).is_ok() { + Ok(vm.new_pyobj(u8::try_from_object(vm, value.clone())?)) + } else { + Err(vm.new_type_error(format!("int expected instead of {}", value.class().name()))) + } + } + "z" => { + if value.clone().downcast_exact::(vm).is_ok() + || value.clone().downcast_exact::(vm).is_ok() + { + Ok(value.clone()) + } else { + Err(vm.new_type_error(format!( + "bytes or integer address expected instead of {} instance", + value.class().name() + ))) + } + } + "Z" => { + if value.clone().downcast_exact::(vm).is_ok() { + Ok(value.clone()) + } else { + Err(vm.new_type_error(format!( + "unicode string or integer address expected instead of {} instance", + value.class().name() + ))) + } + } + _ => { + // "P" + if value.clone().downcast_exact::(vm).is_ok() + || value.clone().downcast_exact::(vm).is_ok() + { + Ok(value.clone()) + } else { + Err(vm.new_type_error("cannot be converted to pointer".to_string())) + } + } + } +} + +pub struct RawBuffer { + #[allow(dead_code)] + pub inner: Box<[u8]>, + #[allow(dead_code)] + pub size: usize, +} + +#[pyclass(name = "_CData", module = "_ctypes")] +pub struct PyCData { + _objects: AtomicCell>, + _buffer: PyRwLock, +} + +#[pyclass] +impl PyCData {} + +#[pyclass(module = "_ctypes", name = "PyCSimpleType", base = "PyType")] +pub struct PySimpleMeta {} + +#[pyclass(flags(BASETYPE))] +impl PySimpleMeta { + #[allow(clippy::new_ret_no_self)] + #[pymethod] + fn new(cls: PyTypeRef, _: OptionalArg, vm: &VirtualMachine) -> PyResult { + Ok(PyObjectRef::from( + new_simple_type(Either::B(&cls), vm)? + .into_ref_with_type(vm, cls)? 
+ .clone(), + )) + } +} + +#[pyclass( + name = "_SimpleCData", + base = "PyCData", + module = "_ctypes", + metaclass = "PySimpleMeta" +)] +#[derive(PyPayload)] +pub struct PyCSimple { + pub _type_: String, + pub value: AtomicCell, +} + +impl Debug for PyCSimple { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PyCSimple") + .field("_type_", &self._type_) + .finish() + } +} + +impl Constructor for PyCSimple { + type Args = (OptionalArg,); + + fn py_new(cls: PyTypeRef, args: Self::Args, vm: &VirtualMachine) -> PyResult { + let attributes = cls.get_attributes(); + let _type_ = attributes + .iter() + .find(|(k, _)| k.to_object().str(vm).unwrap().to_string() == *"_type_") + .unwrap() + .1 + .str(vm)? + .to_string(); + let value = if let Some(ref v) = args.0.into_option() { + set_primitive(_type_.as_str(), v, vm)? + } else { + match _type_.as_str() { + "c" | "u" => PyObjectRef::from(vm.ctx.new_bytes(vec![0])), + "b" | "B" | "h" | "H" | "i" | "I" | "l" | "q" | "L" | "Q" => { + PyObjectRef::from(vm.ctx.new_int(0)) + } + "f" | "d" | "g" => PyObjectRef::from(vm.ctx.new_float(0.0)), + "?" 
=> PyObjectRef::from(vm.ctx.new_bool(false)), + _ => vm.ctx.none(), // "z" | "Z" | "P" + } + }; + Ok(PyCSimple { + _type_, + value: AtomicCell::new(value), + } + .to_pyobject(vm)) + } +} + +#[pyclass(flags(BASETYPE), with(Constructor))] +impl PyCSimple { + #[pygetset(name = "value")] + pub fn value(instance: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let zelf: &Py = instance + .downcast_ref() + .ok_or_else(|| vm.new_type_error("cannot get value of instance".to_string()))?; + Ok(unsafe { (*zelf.value.as_ptr()).clone() }) + } + + #[pygetset(name = "value", setter)] + fn set_value(instance: PyObjectRef, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + let zelf: PyRef = instance + .downcast() + .map_err(|_| vm.new_type_error("cannot set value of instance".to_string()))?; + let content = set_primitive(zelf._type_.as_str(), &value, vm)?; + zelf.value.store(content); + Ok(()) + } +} diff --git a/vm/src/stdlib/ctypes/function.rs b/vm/src/stdlib/ctypes/function.rs new file mode 100644 index 0000000000..a7ee07744b --- /dev/null +++ b/vm/src/stdlib/ctypes/function.rs @@ -0,0 +1,24 @@ +use crate::PyObjectRef; +use crate::stdlib::ctypes::PyCData; +use crossbeam_utils::atomic::AtomicCell; +use rustpython_common::lock::PyRwLock; +use std::ffi::c_void; + +#[derive(Debug)] +pub struct Function { + _pointer: *mut c_void, + _arguments: Vec<()>, + _return_type: Box<()>, +} + +#[pyclass(module = "_ctypes", name = "CFuncPtr", base = "PyCData")] +pub struct PyCFuncPtr { + pub _name_: String, + pub _argtypes_: AtomicCell>, + pub _restype_: AtomicCell, + _handle: PyObjectRef, + _f: PyRwLock, +} + +#[pyclass] +impl PyCFuncPtr {} diff --git a/vm/src/stdlib/ctypes/library.rs b/vm/src/stdlib/ctypes/library.rs new file mode 100644 index 0000000000..94b6327440 --- /dev/null +++ b/vm/src/stdlib/ctypes/library.rs @@ -0,0 +1,115 @@ +use crate::VirtualMachine; +use crossbeam_utils::atomic::AtomicCell; +use libloading::Library; +use rustpython_common::lock::PyRwLock; +use 
std::collections::HashMap; +use std::ffi::c_void; +use std::fmt; +use std::ptr::null; + +pub struct SharedLibrary { + lib: AtomicCell>, +} + +impl fmt::Debug for SharedLibrary { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "SharedLibrary") + } +} + +impl SharedLibrary { + pub fn new(name: &str) -> Result { + Ok(SharedLibrary { + lib: AtomicCell::new(Some(unsafe { Library::new(name)? })), + }) + } + + #[allow(dead_code)] + pub fn get_sym(&self, name: &str) -> Result<*mut c_void, String> { + if let Some(inner) = unsafe { &*self.lib.as_ptr() } { + unsafe { + inner + .get(name.as_bytes()) + .map(|f: libloading::Symbol<*mut c_void>| *f) + .map_err(|err| err.to_string()) + } + } else { + Err("The library has been closed".to_string()) + } + } + + pub fn get_pointer(&self) -> usize { + if let Some(l) = unsafe { &*self.lib.as_ptr() } { + l as *const Library as usize + } else { + null::() as usize + } + } + + pub fn is_closed(&self) -> bool { + unsafe { &*self.lib.as_ptr() }.is_none() + } + + pub fn close(&self) { + let old = self.lib.take(); + self.lib.store(None); + drop(old); + } +} + +impl Drop for SharedLibrary { + fn drop(&mut self) { + self.close(); + } +} + +pub struct ExternalLibs { + libraries: HashMap, +} + +impl ExternalLibs { + pub fn new() -> Self { + Self { + libraries: HashMap::new(), + } + } + + #[allow(dead_code)] + pub fn get_lib(&self, key: usize) -> Option<&SharedLibrary> { + self.libraries.get(&key) + } + + pub fn get_or_insert_lib( + &mut self, + library_path: &str, + _vm: &VirtualMachine, + ) -> Result<&SharedLibrary, libloading::Error> { + let nlib = SharedLibrary::new(library_path)?; + let key = nlib.get_pointer(); + + match self.libraries.get(&key) { + Some(l) => { + if l.is_closed() { + self.libraries.insert(key, nlib); + } + } + _ => { + self.libraries.insert(key, nlib); + } + }; + + Ok(self.libraries.get(&key).unwrap()) + } + + pub fn drop_lib(&mut self, key: usize) { + self.libraries.remove(&key); + } +} + 
+rustpython_common::static_cell! { + static LIBCACHE: PyRwLock; +} + +pub fn libcache() -> &'static PyRwLock { + LIBCACHE.get_or_init(|| PyRwLock::new(ExternalLibs::new())) +} diff --git a/vm/src/stdlib/ctypes/pointer.rs b/vm/src/stdlib/ctypes/pointer.rs new file mode 100644 index 0000000000..d1360f9862 --- /dev/null +++ b/vm/src/stdlib/ctypes/pointer.rs @@ -0,0 +1,5 @@ +#[pyclass(name = "Pointer", module = "_ctypes")] +pub struct PyCPointer {} + +#[pyclass(flags(BASETYPE, IMMUTABLETYPE))] +impl PyCPointer {} diff --git a/vm/src/stdlib/ctypes/structure.rs b/vm/src/stdlib/ctypes/structure.rs new file mode 100644 index 0000000000..13cca6c260 --- /dev/null +++ b/vm/src/stdlib/ctypes/structure.rs @@ -0,0 +1,5 @@ +#[pyclass(name = "Structure", module = "_ctypes")] +pub struct PyCStructure {} + +#[pyclass(flags(BASETYPE, IMMUTABLETYPE))] +impl PyCStructure {} diff --git a/vm/src/stdlib/ctypes/union.rs b/vm/src/stdlib/ctypes/union.rs new file mode 100644 index 0000000000..5a39d9062e --- /dev/null +++ b/vm/src/stdlib/ctypes/union.rs @@ -0,0 +1,5 @@ +#[pyclass(name = "Union", module = "_ctypes")] +pub struct PyCUnion {} + +#[pyclass(flags(BASETYPE, IMMUTABLETYPE))] +impl PyCUnion {} diff --git a/vm/src/stdlib/errno.rs b/vm/src/stdlib/errno.rs index a142d68a34..247e2a2340 100644 --- a/vm/src/stdlib/errno.rs +++ b/vm/src/stdlib/errno.rs @@ -1,4 +1,4 @@ -use crate::{builtins::PyModule, PyRef, VirtualMachine}; +use crate::{PyRef, VirtualMachine, builtins::PyModule}; #[pymodule] mod errno {} @@ -38,9 +38,9 @@ pub mod errors { WSAEPROVIDERFAILEDINIT, WSAEREFUSED, WSAEREMOTE, WSAESHUTDOWN, WSAESOCKTNOSUPPORT, WSAESTALE, WSAETIMEDOUT, WSAETOOMANYREFS, WSAEUSERS, WSAEWOULDBLOCK, WSAHOST_NOT_FOUND, WSAID_ACCEPTEX, WSAID_CONNECTEX, WSAID_DISCONNECTEX, WSAID_GETACCEPTEXSOCKADDRS, - WSAID_TRANSMITFILE, WSAID_TRANSMITPACKETS, WSAID_WSAPOLL, WSAID_WSARECVMSG, - WSANOTINITIALISED, WSANO_DATA, WSANO_RECOVERY, WSAPROTOCOL_LEN, WSASERVICE_NOT_FOUND, - WSASYSCALLFAILURE, WSASYSNOTREADY, 
WSASYS_STATUS_LEN, WSATRY_AGAIN, WSATYPE_NOT_FOUND, + WSAID_TRANSMITFILE, WSAID_TRANSMITPACKETS, WSAID_WSAPOLL, WSAID_WSARECVMSG, WSANO_DATA, + WSANO_RECOVERY, WSANOTINITIALISED, WSAPROTOCOL_LEN, WSASERVICE_NOT_FOUND, + WSASYS_STATUS_LEN, WSASYSCALLFAILURE, WSASYSNOTREADY, WSATRY_AGAIN, WSATYPE_NOT_FOUND, WSAVERNOTSUPPORTED, }, }; diff --git a/vm/src/stdlib/functools.rs b/vm/src/stdlib/functools.rs index d13b9b84f6..145d95d6ff 100644 --- a/vm/src/stdlib/functools.rs +++ b/vm/src/stdlib/functools.rs @@ -2,7 +2,7 @@ pub(crate) use _functools::make_module; #[pymodule] mod _functools { - use crate::{function::OptionalArg, protocol::PyIter, PyObjectRef, PyResult, VirtualMachine}; + use crate::{PyObjectRef, PyResult, VirtualMachine, function::OptionalArg, protocol::PyIter}; #[pyfunction] fn reduce( diff --git a/vm/src/stdlib/imp.rs b/vm/src/stdlib/imp.rs index d8727e74ff..5c3f4bf61d 100644 --- a/vm/src/stdlib/imp.rs +++ b/vm/src/stdlib/imp.rs @@ -1,11 +1,11 @@ use crate::frozen::FrozenModule; -use crate::{builtins::PyBaseExceptionRef, VirtualMachine}; +use crate::{VirtualMachine, builtins::PyBaseExceptionRef}; pub(crate) use _imp::make_module; #[cfg(feature = "threading")] #[pymodule(sub)] mod lock { - use crate::{stdlib::thread::RawRMutex, PyResult, VirtualMachine}; + use crate::{PyResult, VirtualMachine, stdlib::thread::RawRMutex}; static IMP_LOCK: RawRMutex = RawRMutex::INIT; @@ -60,7 +60,9 @@ impl FrozenError { use FrozenError::*; let msg = match self { BadName | NotFound => format!("No such frozen object named {mod_name}"), - Disabled => format!("Frozen modules are disabled and the frozen object named {mod_name} is not essential"), + Disabled => format!( + "Frozen modules are disabled and the frozen object named {mod_name} is not essential" + ), Excluded => format!("Excluded frozen object named {mod_name}"), Invalid => format!("Frozen object named {mod_name} is invalid"), }; @@ -80,9 +82,10 @@ fn find_frozen(name: &str, vm: &VirtualMachine) -> Result 
PyBaseExceptionRef { @@ -118,6 +118,8 @@ impl std::os::fd::AsRawFd for Fildes { mod _io { use super::*; use crate::{ + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, + TryFromBorrowedObject, TryFromObject, builtins::{ PyBaseExceptionRef, PyByteArray, PyBytes, PyBytesRef, PyIntRef, PyMemoryView, PyStr, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, @@ -140,8 +142,6 @@ mod _io { Callable, Constructor, DefaultConstructor, Destructor, Initializer, IterNext, Iterable, }, vm::VirtualMachine, - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, - TryFromBorrowedObject, TryFromObject, }; use bstr::ByteSlice; use crossbeam_utils::atomic::AtomicCell; @@ -149,7 +149,7 @@ mod _io { use num_traits::ToPrimitive; use std::{ borrow::Cow, - io::{self, prelude::*, Cursor, SeekFrom}, + io::{self, Cursor, SeekFrom, prelude::*}, ops::Range, }; @@ -319,11 +319,7 @@ mod _io { // For Cursor, fill_buf returns all of the remaining data unlike other BufReads which have outer reading source. // Unless we add other data by write, there will be no more data. 
let buf = self.cursor.fill_buf().map_err(|err| os_err(vm, err))?; - if size < buf.len() { - &buf[..size] - } else { - buf - } + if size < buf.len() { &buf[..size] } else { buf } }; let buf = match available.find_byte(byte) { Some(i) => available[..=i].to_vec(), @@ -482,7 +478,7 @@ mod _io { let size = size.to_usize(); let read = instance.get_attr("read", vm)?; let mut res = Vec::new(); - while size.map_or(true, |s| res.len() < s) { + while size.is_none_or(|s| res.len() < s) { let read_res = ArgBytesLike::try_from_object(vm, read.call((1,), vm)?)?; if read_res.with_ref(|b| b.is_empty()) { break; @@ -2496,7 +2492,7 @@ mod _io { _ => { return Err( vm.new_value_error(format!("invalid whence ({how}, should be 0, 1 or 2)")) - ) + ); } }; use crate::types::PyComparisonOp; @@ -3006,7 +3002,7 @@ mod _io { } fn parse_decoder_state(state: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyBytesRef, i32)> { - use crate::builtins::{int, PyTuple}; + use crate::builtins::{PyTuple, int}; let state_err = || vm.new_type_error("illegal decoder state".to_owned()); let state = state.downcast::().map_err(|_| state_err())?; match state.as_slice() { @@ -4032,8 +4028,9 @@ mod _io { #[cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))] #[pymodule] mod fileio { - use super::{Offset, _io::*}; + use super::{_io::*, Offset}; use crate::{ + AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, builtins::{PyBaseExceptionRef, PyStr, PyStrRef}, common::crt_fd::Fd, convert::ToPyException, @@ -4041,7 +4038,6 @@ mod fileio { ospath::{IOErrorBuilder, OsPath, OsPathOrFd}, stdlib::os, types::{Constructor, DefaultConstructor, Initializer, Representable}, - AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; use std::io::{Read, Write}; diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index eb800a0b1f..65c1482057 100644 --- a/vm/src/stdlib/itertools.rs +++ 
b/vm/src/stdlib/itertools.rs @@ -4,9 +4,11 @@ pub(crate) use decl::make_module; mod decl { use crate::stdlib::itertools::decl::int::get_value; use crate::{ + AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, PyWeakRef, TryFromObject, + VirtualMachine, builtins::{ - int, tuple::IntoPyTuple, PyGenericAlias, PyInt, PyIntRef, PyList, PyTuple, PyTupleRef, - PyTypeRef, + PyGenericAlias, PyInt, PyIntRef, PyList, PyTuple, PyTupleRef, PyTypeRef, int, + tuple::IntoPyTuple, }, common::{ lock::{PyMutex, PyRwLock, PyRwLockWriteGuard}, @@ -18,8 +20,6 @@ mod decl { protocol::{PyIter, PyIterReturn, PyNumber}, stdlib::sys, types::{Constructor, IterNext, Iterable, Representable, SelfIter}, - AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, PyWeakRef, TryFromObject, - VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; use malachite_bigint::BigInt; diff --git a/vm/src/stdlib/marshal.rs b/vm/src/stdlib/marshal.rs index 35f9b93d5f..fd7332e7c2 100644 --- a/vm/src/stdlib/marshal.rs +++ b/vm/src/stdlib/marshal.rs @@ -5,6 +5,7 @@ mod decl { use crate::builtins::code::{CodeObject, Literal, PyObjBag}; use crate::class::StaticType; use crate::{ + PyObjectRef, PyResult, TryFromObject, VirtualMachine, builtins::{ PyBool, PyByteArray, PyBytes, PyCode, PyComplex, PyDict, PyEllipsis, PyFloat, PyFrozenSet, PyInt, PyList, PyNone, PySet, PyStopIteration, PyStr, PyTuple, @@ -13,7 +14,6 @@ mod decl { function::{ArgBytesLike, OptionalArg}, object::AsObject, protocol::PyBuffer, - PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; use malachite_bigint::BigInt; use num_complex::Complex64; diff --git a/vm/src/stdlib/mod.rs b/vm/src/stdlib/mod.rs index 12baee11f7..529a40e861 100644 --- a/vm/src/stdlib/mod.rs +++ b/vm/src/stdlib/mod.rs @@ -37,6 +37,8 @@ pub mod posix; #[path = "posix_compat.rs"] pub mod posix; +#[cfg(any(target_family = "unix", target_family = "windows"))] +mod ctypes; #[cfg(windows)] pub(crate) mod msvcrt; #[cfg(all(unix, not(any(target_os = "android", target_os = 
"redox"))))] @@ -48,7 +50,7 @@ mod winapi; #[cfg(windows)] mod winreg; -use crate::{builtins::PyModule, PyRef, VirtualMachine}; +use crate::{PyRef, VirtualMachine, builtins::PyModule}; use std::{borrow::Cow, collections::HashMap}; pub type StdlibInitFunc = Box PyRef)>; @@ -124,5 +126,9 @@ pub fn get_module_inits() -> StdlibMap { "_winapi" => winapi::make_module, "winreg" => winreg::make_module, } + #[cfg(any(target_family = "unix", target_family = "windows"))] + { + "_ctypes" => ctypes::make_module, + } } } diff --git a/vm/src/stdlib/msvcrt.rs b/vm/src/stdlib/msvcrt.rs index 03ddb44f22..7b3620ad51 100644 --- a/vm/src/stdlib/msvcrt.rs +++ b/vm/src/stdlib/msvcrt.rs @@ -3,10 +3,10 @@ pub use msvcrt::*; #[pymodule] mod msvcrt { use crate::{ + PyRef, PyResult, VirtualMachine, builtins::{PyBytes, PyStrRef}, common::suppress_iph, stdlib::os::errno_err, - PyRef, PyResult, VirtualMachine, }; use itertools::Itertools; use windows_sys::Win32::{ @@ -24,7 +24,7 @@ mod msvcrt { unsafe { suppress_iph!(_setmode(fd, libc::O_BINARY)) }; } - extern "C" { + unsafe extern "C" { fn _getch() -> i32; fn _getwch() -> u32; fn _getche() -> i32; @@ -70,7 +70,7 @@ mod msvcrt { Ok(()) } - extern "C" { + unsafe extern "C" { fn _setmode(fd: i32, flags: i32) -> i32; } @@ -84,7 +84,7 @@ mod msvcrt { } } - extern "C" { + unsafe extern "C" { fn _open_osfhandle(osfhandle: isize, flags: i32) -> i32; fn _get_osfhandle(fd: i32) -> libc::intptr_t; } diff --git a/vm/src/stdlib/nt.rs b/vm/src/stdlib/nt.rs index 803cade630..34fa8792d5 100644 --- a/vm/src/stdlib/nt.rs +++ b/vm/src/stdlib/nt.rs @@ -1,4 +1,4 @@ -use crate::{builtins::PyModule, PyRef, VirtualMachine}; +use crate::{PyRef, VirtualMachine, builtins::PyModule}; pub use module::raw_set_handle_inheritable; @@ -11,13 +11,13 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { #[pymodule(name = "nt", with(super::os::_os))] pub(crate) mod module { use crate::{ + PyResult, TryFromObject, VirtualMachine, builtins::{PyDictRef, PyListRef, PyStrRef, 
PyTupleRef}, common::{crt_fd::Fd, os::last_os_error, suppress_iph}, convert::ToPyException, function::{Either, OptionalArg}, ospath::OsPath, - stdlib::os::{errno_err, DirFd, FollowSymlinks, SupportFunc, TargetIsDirectory, _os}, - PyResult, TryFromObject, VirtualMachine, + stdlib::os::{_os, DirFd, FollowSymlinks, SupportFunc, TargetIsDirectory, errno_err}, }; use libc::intptr_t; use std::{ @@ -43,6 +43,12 @@ pub(crate) mod module { || attr & FileSystem::FILE_ATTRIBUTE_DIRECTORY != 0)) } + #[pyfunction] + pub(super) fn _supports_virtual_terminal() -> PyResult { + // TODO: implement this + Ok(true) + } + #[derive(FromArgs)] pub(super) struct SymlinkArgs { src: OsPath, @@ -110,7 +116,7 @@ pub(crate) mod module { // cwait is available on MSVC only (according to CPython) #[cfg(target_env = "msvc")] - extern "C" { + unsafe extern "C" { fn _cwait(termstat: *mut i32, procHandle: intptr_t, action: i32) -> intptr_t; } @@ -188,7 +194,7 @@ pub(crate) mod module { } #[cfg(target_env = "msvc")] - extern "C" { + unsafe extern "C" { fn _wexecv(cmdname: *const u16, argv: *const *const u16) -> intptr_t; } diff --git a/vm/src/stdlib/operator.rs b/vm/src/stdlib/operator.rs index 19c1f4b61b..d1a4b376e8 100644 --- a/vm/src/stdlib/operator.rs +++ b/vm/src/stdlib/operator.rs @@ -4,6 +4,7 @@ pub(crate) use _operator::make_module; mod _operator { use crate::common::cmp; use crate::{ + AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{PyInt, PyIntRef, PyStr, PyStrRef, PyTupleRef, PyTypeRef}, function::Either, function::{ArgBytesLike, FuncArgs, KwArgs, OptionalArg}, @@ -11,7 +12,6 @@ mod _operator { protocol::PyIter, recursion::ReprGuard, types::{Callable, Constructor, PyComparisonOp, Representable}, - AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; #[pyfunction] @@ -336,7 +336,7 @@ mod _operator { _ => { return Err(vm.new_type_error( "unsupported operand types(s) or combination of types".to_owned(), - )) + )); } }; Ok(res) diff --git 
a/vm/src/stdlib/os.rs b/vm/src/stdlib/os.rs index 0b3f617552..39701cb3a3 100644 --- a/vm/src/stdlib/os.rs +++ b/vm/src/stdlib/os.rs @@ -1,9 +1,9 @@ use crate::{ + AsObject, Py, PyPayload, PyResult, VirtualMachine, builtins::{PyBaseExceptionRef, PyModule, PySet}, common::crt_fd::Fd, convert::ToPyException, function::{ArgumentError, FromArgs, FuncArgs}, - AsObject, Py, PyPayload, PyResult, VirtualMachine, }; use std::{ffi, fs, io, path::Path}; @@ -125,8 +125,9 @@ fn bytes_as_osstr<'a>(b: &'a [u8], vm: &VirtualMachine) -> PyResult<&'a ffi::OsS #[pymodule(sub)] pub(super) mod _os { - use super::{errno_err, DirFd, FollowSymlinks, SupportFunc}; + use super::{DirFd, FollowSymlinks, SupportFunc, errno_err}; use crate::{ + AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, builtins::{ PyBytesRef, PyGenericAlias, PyIntRef, PyStrRef, PyTuple, PyTupleRef, PyTypeRef, }, @@ -144,7 +145,6 @@ pub(super) mod _os { types::{IterNext, Iterable, PyStructSequence, Representable, SelfIter}, utils::ToCString, vm::VirtualMachine, - AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, }; use crossbeam_utils::atomic::AtomicCell; use itertools::Itertools; @@ -303,11 +303,7 @@ pub(super) mod _os { #[cfg(target_os = "redox")] let [] = dir_fd.0; let res = unsafe { libc::mkdir(path.as_ptr(), mode as _) }; - if res < 0 { - Err(errno_err(vm)) - } else { - Ok(()) - } + if res < 0 { Err(errno_err(vm)) } else { Ok(()) } } #[pyfunction] @@ -380,8 +376,8 @@ pub(super) mod _os { fn env_bytes_as_bytes(obj: &Either) -> &[u8] { match obj { - Either::A(ref s) => s.as_str().as_bytes(), - Either::B(ref b) => b.as_bytes(), + Either::A(s) => s.as_str().as_bytes(), + Either::B(b) => b.as_bytes(), } } @@ -401,7 +397,8 @@ pub(super) mod _os { } let key = super::bytes_as_osstr(key, vm)?; let value = super::bytes_as_osstr(value, vm)?; - env::set_var(key, value); + // SAFETY: requirements forwarded from the caller + unsafe { env::set_var(key, value) }; Ok(()) } @@ -421,7 +418,8 @@ 
pub(super) mod _os { )); } let key = super::bytes_as_osstr(key, vm)?; - env::remove_var(key); + // SAFETY: requirements forwarded from the caller + unsafe { env::remove_var(key) }; Ok(()) } @@ -966,7 +964,7 @@ pub(super) mod _os { #[pyfunction] fn abort() { - extern "C" { + unsafe extern "C" { fn abort(); } unsafe { abort() } @@ -978,10 +976,7 @@ pub(super) mod _os { return Err(vm.new_value_error("negative argument not allowed".to_owned())); } let mut buf = vec![0u8; size as usize]; - getrandom::getrandom(&mut buf).map_err(|e| match e.raw_os_error() { - Some(errno) => io::Error::from_raw_os_error(errno).into_pyexception(vm), - None => vm.new_os_error("Getting random failed".to_owned()), - })?; + getrandom::fill(&mut buf).map_err(|e| io::Error::from(e).into_pyexception(vm))?; Ok(buf) } @@ -1012,11 +1007,7 @@ pub(super) mod _os { std::mem::transmute::<[i32; 2], i64>(distance_to_move) } }; - if res < 0 { - Err(errno_err(vm)) - } else { - Ok(res) - } + if res < 0 { Err(errno_err(vm)) } else { Ok(res) } } #[pyfunction] @@ -1101,7 +1092,7 @@ pub(super) mod _os { (Some(_), Some(_)) => { return Err(vm.new_value_error( "utime: you may specify either 'times' or 'ns' but not both".to_owned(), - )) + )); } }; utime_impl(args.path, acc, modif, args.dir_fd, args.follow_symlinks, vm) @@ -1139,11 +1130,7 @@ pub(super) mod _os { }, ) }; - if ret < 0 { - Err(errno_err(vm)) - } else { - Ok(()) - } + if ret < 0 { Err(errno_err(vm)) } else { Ok(()) } } #[cfg(target_os = "redox")] { diff --git a/vm/src/stdlib/posix.rs b/vm/src/stdlib/posix.rs index 1b843bab60..05b1d8addd 100644 --- a/vm/src/stdlib/posix.rs +++ b/vm/src/stdlib/posix.rs @@ -1,4 +1,4 @@ -use crate::{builtins::PyModule, PyRef, VirtualMachine}; +use crate::{PyRef, VirtualMachine, builtins::PyModule}; use std::os::unix::io::RawFd; pub fn raw_set_inheritable(fd: RawFd, inheritable: bool) -> nix::Result<()> { @@ -21,16 +21,16 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { #[pymodule(name = "posix", 
with(super::os::_os))] pub mod module { use crate::{ + AsObject, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, builtins::{PyDictRef, PyInt, PyListRef, PyStrRef, PyTupleRef, PyTypeRef}, convert::{IntoPyException, ToPyObject, TryFromObject}, function::{Either, KwArgs, OptionalArg}, ospath::{IOErrorBuilder, OsPath, OsPathOrFd}, stdlib::os::{ - errno_err, DirFd, FollowSymlinks, SupportFunc, TargetIsDirectory, _os, fs_metadata, + _os, DirFd, FollowSymlinks, SupportFunc, TargetIsDirectory, errno_err, fs_metadata, }, types::{Constructor, Representable}, utils::ToCString, - AsObject, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, }; use bitflags::bitflags; use nix::{ @@ -352,11 +352,7 @@ pub mod module { { let [] = args.dir_fd.0; let res = unsafe { libc::symlink(src.as_ptr(), dst.as_ptr()) }; - if res < 0 { - Err(errno_err(vm)) - } else { - Ok(()) - } + if res < 0 { Err(errno_err(vm)) } else { Ok(()) } } } @@ -604,21 +600,13 @@ pub mod module { ) }, }; - if ret != 0 { - Err(errno_err(vm)) - } else { - Ok(()) - } + if ret != 0 { Err(errno_err(vm)) } else { Ok(()) } } #[cfg(target_vendor = "apple")] fn mknod(self, vm: &VirtualMachine) -> PyResult<()> { let [] = self.dir_fd.0; let ret = self._mknod(vm)?; - if ret != 0 { - Err(errno_err(vm)) - } else { - Ok(()) - } + if ret != 0 { Err(errno_err(vm)) } else { Ok(()) } } } @@ -863,7 +851,7 @@ pub mod module { #[pyfunction] fn set_blocking(fd: RawFd, blocking: bool, vm: &VirtualMachine) -> PyResult<()> { let _set_flag = || { - use nix::fcntl::{fcntl, FcntlArg, OFlag}; + use nix::fcntl::{FcntlArg, OFlag, fcntl}; let flags = OFlag::from_bits_truncate(fcntl(fd, FcntlArg::F_GETFL)?); let mut new_flags = flags; @@ -971,7 +959,7 @@ pub mod module { #[cfg(any(target_os = "macos", target_os = "freebsd", target_os = "netbsd",))] #[pyfunction] fn lchmod(path: OsPath, mode: u32, vm: &VirtualMachine) -> PyResult<()> { - extern "C" { + unsafe extern "C" { fn lchmod(path: *const libc::c_char, mode: libc::mode_t) -> libc::c_int; 
} let c_path = path.clone().into_cstring(vm)?; @@ -1605,7 +1593,7 @@ pub mod module { // from libstd: // https://github.com/rust-lang/rust/blob/daecab3a784f28082df90cebb204998051f3557d/src/libstd/sys/unix/fs.rs#L1251 #[cfg(target_os = "macos")] - extern "C" { + unsafe extern "C" { fn fcopyfile( in_fd: libc::c_int, out_fd: libc::c_int, @@ -1618,11 +1606,7 @@ pub mod module { #[pyfunction] fn _fcopyfile(in_fd: i32, out_fd: i32, flags: i32, vm: &VirtualMachine) -> PyResult<()> { let ret = unsafe { fcopyfile(in_fd, out_fd, std::ptr::null_mut(), flags as u32) }; - if ret < 0 { - Err(errno_err(vm)) - } else { - Ok(()) - } + if ret < 0 { Err(errno_err(vm)) } else { Ok(()) } } #[pyfunction] @@ -2299,7 +2283,7 @@ pub mod module { #[cfg(target_os = "linux")] unsafe fn sys_getrandom(buf: *mut libc::c_void, buflen: usize, flags: u32) -> isize { - libc::syscall(libc::SYS_getrandom, buf, buflen, flags as usize) as _ + unsafe { libc::syscall(libc::SYS_getrandom, buf, buflen, flags as usize) as _ } } #[cfg(target_os = "linux")] diff --git a/vm/src/stdlib/posix_compat.rs b/vm/src/stdlib/posix_compat.rs index 95ed932fe5..334aa597ce 100644 --- a/vm/src/stdlib/posix_compat.rs +++ b/vm/src/stdlib/posix_compat.rs @@ -1,5 +1,5 @@ //! 
`posix` compatible module for `not(any(unix, windows))` -use crate::{builtins::PyModule, PyRef, VirtualMachine}; +use crate::{PyRef, VirtualMachine, builtins::PyModule}; pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { let module = module::make_module(vm); @@ -10,10 +10,10 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { #[pymodule(name = "posix", with(super::os::_os))] pub(crate) mod module { use crate::{ + PyObjectRef, PyResult, VirtualMachine, builtins::PyStrRef, ospath::OsPath, - stdlib::os::{DirFd, SupportFunc, TargetIsDirectory, _os}, - PyObjectRef, PyResult, VirtualMachine, + stdlib::os::{_os, DirFd, SupportFunc, TargetIsDirectory}, }; use std::env; diff --git a/vm/src/stdlib/pwd.rs b/vm/src/stdlib/pwd.rs index 0edca9c0a6..f6c277242c 100644 --- a/vm/src/stdlib/pwd.rs +++ b/vm/src/stdlib/pwd.rs @@ -3,11 +3,11 @@ pub(crate) use pwd::make_module; #[pymodule] mod pwd { use crate::{ + PyObjectRef, PyResult, VirtualMachine, builtins::{PyIntRef, PyStrRef}, convert::{IntoPyException, ToPyObject}, exceptions, types::PyStructSequence, - PyObjectRef, PyResult, VirtualMachine, }; use nix::unistd::{self, User}; use std::ptr::NonNull; diff --git a/vm/src/stdlib/signal.rs b/vm/src/stdlib/signal.rs index 8c5db53d17..0c47ca082e 100644 --- a/vm/src/stdlib/signal.rs +++ b/vm/src/stdlib/signal.rs @@ -1,4 +1,4 @@ -use crate::{builtins::PyModule, PyRef, VirtualMachine}; +use crate::{PyRef, VirtualMachine, builtins::PyModule}; pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { let module = _signal::make_module(vm); @@ -13,10 +13,10 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { pub(crate) mod _signal { #[cfg(any(unix, windows))] use crate::{ - convert::{IntoPyException, TryFromBorrowedObject}, Py, + convert::{IntoPyException, TryFromBorrowedObject}, }; - use crate::{signal, PyObjectRef, PyResult, VirtualMachine}; + use crate::{PyObjectRef, PyResult, VirtualMachine, signal}; use std::sync::atomic::{self, Ordering}; #[cfg(any(unix, 
windows))] @@ -78,7 +78,7 @@ pub(crate) mod _signal { pub const SIG_ERR: sighandler_t = -1 as _; #[cfg(all(unix, not(target_os = "redox")))] - extern "C" { + unsafe extern "C" { fn siginterrupt(sig: i32, flag: i32) -> i32; } diff --git a/vm/src/stdlib/sre.rs b/vm/src/stdlib/sre.rs index 8442002fa5..7d620c13d9 100644 --- a/vm/src/stdlib/sre.rs +++ b/vm/src/stdlib/sre.rs @@ -3,7 +3,8 @@ pub(crate) use _sre::make_module; #[pymodule] mod _sre { use crate::{ - atomic_func, + Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromBorrowedObject, + TryFromObject, VirtualMachine, atomic_func, builtins::{ PyCallableIterator, PyDictRef, PyGenericAlias, PyInt, PyList, PyListRef, PyStr, PyStrRef, PyTuple, PyTupleRef, PyTypeRef, @@ -14,16 +15,14 @@ mod _sre { protocol::{PyBuffer, PyCallable, PyMappingMethods}, stdlib::sys, types::{AsMapping, Comparable, Hashable, Representable}, - Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromBorrowedObject, - TryFromObject, VirtualMachine, }; use core::str; use crossbeam_utils::atomic::AtomicCell; use itertools::Itertools; use num_traits::ToPrimitive; use rustpython_sre_engine::{ - string::{lower_ascii, lower_unicode, upper_unicode}, Request, SearchIter, SreFlag, State, StrDrive, + string::{lower_ascii, lower_unicode, upper_unicode}, }; #[pyattr] diff --git a/vm/src/stdlib/string.rs b/vm/src/stdlib/string.rs index cedff92d96..8a8a182732 100644 --- a/vm/src/stdlib/string.rs +++ b/vm/src/stdlib/string.rs @@ -7,10 +7,10 @@ pub(crate) use _string::make_module; mod _string { use crate::common::ascii; use crate::{ + PyObjectRef, PyResult, VirtualMachine, builtins::{PyList, PyStrRef}, convert::ToPyException, convert::ToPyObject, - PyObjectRef, PyResult, VirtualMachine, }; use rustpython_format::{ FieldName, FieldNamePart, FieldType, FormatPart, FormatString, FromTemplate, diff --git a/vm/src/stdlib/symtable.rs b/vm/src/stdlib/symtable.rs index 10d79e9a8b..13a4105111 100644 --- a/vm/src/stdlib/symtable.rs +++ 
b/vm/src/stdlib/symtable.rs @@ -3,7 +3,7 @@ pub(crate) use symtable::make_module; #[pymodule] mod symtable { use crate::{ - builtins::PyStrRef, compiler, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, + PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::PyStrRef, compiler, }; use rustpython_codegen::symboltable::{ Symbol, SymbolFlags, SymbolScope, SymbolTable, SymbolTableType, diff --git a/vm/src/stdlib/sys.rs b/vm/src/stdlib/sys.rs index 59414b5fdd..dfaab20f2a 100644 --- a/vm/src/stdlib/sys.rs +++ b/vm/src/stdlib/sys.rs @@ -1,10 +1,11 @@ -use crate::{builtins::PyModule, convert::ToPyObject, Py, PyResult, VirtualMachine}; +use crate::{Py, PyResult, VirtualMachine, builtins::PyModule, convert::ToPyObject}; -pub(crate) use sys::{UnraisableHookArgs, __module_def, DOC, MAXSIZE, MULTIARCH}; +pub(crate) use sys::{__module_def, DOC, MAXSIZE, MULTIARCH, UnraisableHookArgs}; #[pymodule] mod sys { use crate::{ + AsObject, PyObject, PyObjectRef, PyRef, PyRefExact, PyResult, builtins::{ PyBaseExceptionRef, PyDictRef, PyNamespace, PyStr, PyStrRef, PyTupleRef, PyTypeRef, }, @@ -19,7 +20,6 @@ mod sys { types::PyStructSequence, version, vm::{Settings, VirtualMachine}, - AsObject, PyObject, PyObjectRef, PyRef, PyRefExact, PyResult, }; use num_traits::ToPrimitive; use std::{ @@ -85,11 +85,7 @@ mod sys { #[pyattr] fn default_prefix(_vm: &VirtualMachine) -> &'static str { // TODO: the windows one doesn't really make sense - if cfg!(windows) { - "C:" - } else { - "/usr/local" - } + if cfg!(windows) { "C:" } else { "/usr/local" } } #[pyattr] fn prefix(vm: &VirtualMachine) -> &'static str { @@ -574,9 +570,11 @@ mod sys { if vm.is_none(unraisable.exc_type.as_object()) { // TODO: early return, but with what error? 
} - assert!(unraisable - .exc_type - .fast_issubclass(vm.ctx.exceptions.base_exception_type)); + assert!( + unraisable + .exc_type + .fast_issubclass(vm.ctx.exceptions.base_exception_type) + ); // TODO: print module name and qualname diff --git a/vm/src/stdlib/sysconfigdata.rs b/vm/src/stdlib/sysconfigdata.rs index 929227ac11..90e46b83b9 100644 --- a/vm/src/stdlib/sysconfigdata.rs +++ b/vm/src/stdlib/sysconfigdata.rs @@ -2,7 +2,7 @@ pub(crate) use _sysconfigdata::make_module; #[pymodule] pub(crate) mod _sysconfigdata { - use crate::{builtins::PyDictRef, convert::ToPyObject, stdlib::sys::MULTIARCH, VirtualMachine}; + use crate::{VirtualMachine, builtins::PyDictRef, convert::ToPyObject, stdlib::sys::MULTIARCH}; #[pyattr] fn build_time_vars(vm: &VirtualMachine) -> PyDictRef { diff --git a/vm/src/stdlib/thread.rs b/vm/src/stdlib/thread.rs index 63e66474d1..bca7930437 100644 --- a/vm/src/stdlib/thread.rs +++ b/vm/src/stdlib/thread.rs @@ -1,20 +1,20 @@ //! Implementation of the _thread module #[cfg_attr(target_arch = "wasm32", allow(unused_imports))] -pub(crate) use _thread::{make_module, RawRMutex}; +pub(crate) use _thread::{RawRMutex, make_module}; #[pymodule] pub(crate) mod _thread { use crate::{ + AsObject, Py, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{PyDictRef, PyStr, PyTupleRef, PyTypeRef}, convert::ToPyException, function::{ArgCallable, Either, FuncArgs, KwArgs, OptionalArg, PySetterValue}, types::{Constructor, GetAttr, Representable, SetAttr}, - AsObject, Py, PyPayload, PyRef, PyResult, VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; use parking_lot::{ - lock_api::{RawMutex as RawMutexT, RawMutexTimed, RawReentrantMutex}, RawMutex, RawThreadId, + lock_api::{RawMutex as RawMutexT, RawMutexTimed, RawReentrantMutex}, }; use std::{cell::RefCell, fmt, thread, time::Duration}; use thread_local::ThreadLocal; diff --git a/vm/src/stdlib/time.rs b/vm/src/stdlib/time.rs index 566650d0f2..37a518e504 100644 --- a/vm/src/stdlib/time.rs +++ 
b/vm/src/stdlib/time.rs @@ -2,7 +2,7 @@ // See also: // https://docs.python.org/3/library/time.html -use crate::{builtins::PyModule, PyRef, VirtualMachine}; +use crate::{PyRef, VirtualMachine, builtins::PyModule}; pub use decl::time; @@ -17,7 +17,7 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { #[cfg(not(target_env = "msvc"))] #[cfg(not(target_arch = "wasm32"))] -extern "C" { +unsafe extern "C" { #[cfg(not(target_os = "freebsd"))] #[link_name = "daylight"] static c_daylight: std::ffi::c_int; @@ -33,16 +33,19 @@ extern "C" { #[pymodule(name = "time", with(platform))] mod decl { use crate::{ + PyObjectRef, PyResult, TryFromObject, VirtualMachine, builtins::{PyStrRef, PyTypeRef}, function::{Either, FuncArgs, OptionalArg}, types::PyStructSequence, - PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; use chrono::{ - naive::{NaiveDate, NaiveDateTime, NaiveTime}, DateTime, Datelike, Timelike, + naive::{NaiveDate, NaiveDateTime, NaiveTime}, }; use std::time::Duration; + #[cfg(target_env = "msvc")] + #[cfg(not(target_arch = "wasm32"))] + use windows::Win32::System::Time; #[allow(dead_code)] pub(super) const SEC_TO_MS: i64 = 1000; @@ -152,6 +155,15 @@ mod decl { Ok(get_perf_time(vm)?.as_nanos()) } + #[cfg(target_env = "msvc")] + #[cfg(not(target_arch = "wasm32"))] + fn get_tz_info() -> Time::TIME_ZONE_INFORMATION { + let mut info = Time::TIME_ZONE_INFORMATION::default(); + let info_ptr = &mut info as *mut Time::TIME_ZONE_INFORMATION; + let _ = unsafe { Time::GetTimeZoneInformation(info_ptr) }; + info + } + // #[pyfunction] // fn tzset() { // unsafe { super::_tzset() }; @@ -164,6 +176,15 @@ mod decl { unsafe { super::c_timezone } } + #[cfg(target_env = "msvc")] + #[cfg(not(target_arch = "wasm32"))] + #[pyattr] + fn timezone(_vm: &VirtualMachine) -> i32 { + let info = get_tz_info(); + // https://users.rust-lang.org/t/accessing-tzname-and-similar-constants-in-windows/125771/3 + (info.Bias + info.StandardBias) * 60 + } + #[cfg(not(target_os = "freebsd"))] 
#[cfg(not(target_env = "msvc"))] #[cfg(not(target_arch = "wasm32"))] @@ -172,6 +193,15 @@ mod decl { unsafe { super::c_daylight } } + #[cfg(target_env = "msvc")] + #[cfg(not(target_arch = "wasm32"))] + #[pyattr] + fn daylight(_vm: &VirtualMachine) -> i32 { + let info = get_tz_info(); + // https://users.rust-lang.org/t/accessing-tzname-and-similar-constants-in-windows/125771/3 + (info.StandardBias != info.DaylightBias) as i32 + } + #[cfg(not(target_env = "msvc"))] #[cfg(not(target_arch = "wasm32"))] #[pyattr] @@ -179,11 +209,29 @@ mod decl { use crate::builtins::tuple::IntoPyTuple; unsafe fn to_str(s: *const std::ffi::c_char) -> String { - std::ffi::CStr::from_ptr(s).to_string_lossy().into_owned() + unsafe { std::ffi::CStr::from_ptr(s) } + .to_string_lossy() + .into_owned() } unsafe { (to_str(super::c_tzname[0]), to_str(super::c_tzname[1])) }.into_pytuple(vm) } + #[cfg(target_env = "msvc")] + #[cfg(not(target_arch = "wasm32"))] + #[pyattr] + fn tzname(vm: &VirtualMachine) -> crate::builtins::PyTupleRef { + use crate::builtins::tuple::IntoPyTuple; + let info = get_tz_info(); + let standard = widestring::decode_utf16_lossy(info.StandardName) + .filter(|&c| c != '\0') + .collect::(); + let daylight = widestring::decode_utf16_lossy(info.DaylightName) + .filter(|&c| c != '\0') + .collect::(); + let tz_name = (&*standard, &*daylight); + tz_name.into_pytuple(vm) + } + fn pyobj_to_date_time( value: Either, vm: &VirtualMachine, @@ -461,9 +509,9 @@ mod platform { use super::decl::{SEC_TO_NS, US_TO_NS}; #[cfg_attr(target_os = "macos", allow(unused_imports))] use crate::{ + PyObject, PyRef, PyResult, TryFromBorrowedObject, VirtualMachine, builtins::{PyNamespace, PyStrRef}, convert::IntoPyException, - PyObject, PyRef, PyResult, TryFromBorrowedObject, VirtualMachine, }; use nix::{sys::time::TimeSpec, time::ClockId}; use std::time::Duration; @@ -673,7 +721,7 @@ mod platform { target_os = "openbsd", ))] pub(super) fn get_process_time(vm: &VirtualMachine) -> PyResult { - use 
nix::sys::resource::{getrusage, UsageWho}; + use nix::sys::resource::{UsageWho, getrusage}; fn from_timeval(tv: libc::timeval, vm: &VirtualMachine) -> PyResult { (|tv: libc::timeval| { let t = tv.tv_sec.checked_mul(SEC_TO_NS)?; @@ -695,11 +743,11 @@ mod platform { #[cfg(windows)] #[pymodule] mod platform { - use super::decl::{time_muldiv, MS_TO_NS, SEC_TO_NS}; + use super::decl::{MS_TO_NS, SEC_TO_NS, time_muldiv}; use crate::{ + PyRef, PyResult, VirtualMachine, builtins::{PyNamespace, PyStrRef}, stdlib::os::errno_err, - PyRef, PyResult, VirtualMachine, }; use std::time::Duration; use windows_sys::Win32::{ diff --git a/vm/src/stdlib/typing.rs b/vm/src/stdlib/typing.rs index daa0180325..c266e811ca 100644 --- a/vm/src/stdlib/typing.rs +++ b/vm/src/stdlib/typing.rs @@ -3,9 +3,9 @@ pub(crate) use _typing::make_module; #[pymodule] pub(crate) mod _typing { use crate::{ - builtins::{pystr::AsPyStr, PyGenericAlias, PyTupleRef, PyTypeRef}, - function::IntoFuncArgs, PyObjectRef, PyPayload, PyResult, VirtualMachine, + builtins::{PyGenericAlias, PyTupleRef, PyTypeRef, pystr::AsPyStr}, + function::IntoFuncArgs, }; pub(crate) fn _call_typing_func_object<'a>( @@ -75,18 +75,31 @@ pub(crate) mod _typing { #[pyclass(name = "ParamSpec")] #[derive(Debug, PyPayload)] #[allow(dead_code)] - struct ParamSpec {} + pub(crate) struct ParamSpec { + name: PyObjectRef, + } + #[pyclass(flags(BASETYPE))] impl ParamSpec {} + pub(crate) fn make_paramspec(name: PyObjectRef) -> ParamSpec { + ParamSpec { name } + } + #[pyattr] #[pyclass(name = "TypeVarTuple")] #[derive(Debug, PyPayload)] #[allow(dead_code)] - pub(crate) struct TypeVarTuple {} + pub(crate) struct TypeVarTuple { + name: PyObjectRef, + } #[pyclass(flags(BASETYPE))] impl TypeVarTuple {} + pub(crate) fn make_typevartuple(name: PyObjectRef) -> TypeVarTuple { + TypeVarTuple { name } + } + #[pyattr] #[pyclass(name = "ParamSpecArgs")] #[derive(Debug, PyPayload)] diff --git a/vm/src/stdlib/warnings.rs b/vm/src/stdlib/warnings.rs index 
be5ad8131e..a8ffee4579 100644 --- a/vm/src/stdlib/warnings.rs +++ b/vm/src/stdlib/warnings.rs @@ -1,6 +1,6 @@ pub(crate) use _warnings::make_module; -use crate::{builtins::PyType, Py, PyResult, VirtualMachine}; +use crate::{Py, PyResult, VirtualMachine, builtins::PyType}; pub fn warn( category: &Py, @@ -20,9 +20,9 @@ pub fn warn( #[pymodule] mod _warnings { use crate::{ + PyResult, VirtualMachine, builtins::{PyStrRef, PyTypeRef}, function::OptionalArg, - PyResult, VirtualMachine, }; #[derive(FromArgs)] diff --git a/vm/src/stdlib/weakref.rs b/vm/src/stdlib/weakref.rs index 3ef0de6155..7d8924ff52 100644 --- a/vm/src/stdlib/weakref.rs +++ b/vm/src/stdlib/weakref.rs @@ -9,8 +9,8 @@ pub(crate) use _weakref::make_module; #[pymodule] mod _weakref { use crate::{ - builtins::{PyDictRef, PyTypeRef, PyWeak}, PyObjectRef, PyResult, VirtualMachine, + builtins::{PyDictRef, PyTypeRef, PyWeak}, }; #[pyattr(name = "ref")] diff --git a/vm/src/stdlib/winapi.rs b/vm/src/stdlib/winapi.rs index ad6db12474..c1edb2739e 100644 --- a/vm/src/stdlib/winapi.rs +++ b/vm/src/stdlib/winapi.rs @@ -4,18 +4,18 @@ pub(crate) use _winapi::make_module; #[pymodule] mod _winapi { use crate::{ + PyObjectRef, PyResult, TryFromObject, VirtualMachine, builtins::PyStrRef, common::windows::ToWideString, convert::{ToPyException, ToPyResult}, function::{ArgMapping, ArgSequence, OptionalArg}, stdlib::os::errno_err, windows::WindowsSysResult, - PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; use std::ptr::{null, null_mut}; use windows::{ - core::PCWSTR, Win32::Foundation::{HANDLE, HINSTANCE, MAX_PATH}, + core::PCWSTR, }; use windows_sys::Win32::Foundation::{BOOL, HANDLE as RAW_HANDLE}; @@ -28,10 +28,20 @@ mod _winapi { ERROR_PIPE_CONNECTED, ERROR_SEM_TIMEOUT, GENERIC_READ, GENERIC_WRITE, STILL_ACTIVE, WAIT_ABANDONED, WAIT_ABANDONED_0, WAIT_OBJECT_0, WAIT_TIMEOUT, }, + Globalization::{ + LCMAP_FULLWIDTH, LCMAP_HALFWIDTH, LCMAP_HIRAGANA, LCMAP_KATAKANA, + LCMAP_LINGUISTIC_CASING, LCMAP_LOWERCASE, 
LCMAP_SIMPLIFIED_CHINESE, LCMAP_TITLECASE, + LCMAP_TRADITIONAL_CHINESE, LCMAP_UPPERCASE, + }, Storage::FileSystem::{ - FILE_FLAG_FIRST_PIPE_INSTANCE, FILE_FLAG_OVERLAPPED, FILE_GENERIC_READ, - FILE_GENERIC_WRITE, FILE_TYPE_CHAR, FILE_TYPE_DISK, FILE_TYPE_PIPE, FILE_TYPE_REMOTE, - FILE_TYPE_UNKNOWN, OPEN_EXISTING, PIPE_ACCESS_DUPLEX, PIPE_ACCESS_INBOUND, SYNCHRONIZE, + COPYFILE2_CALLBACK_CHUNK_FINISHED, COPYFILE2_CALLBACK_CHUNK_STARTED, + COPYFILE2_CALLBACK_ERROR, COPYFILE2_CALLBACK_POLL_CONTINUE, + COPYFILE2_CALLBACK_STREAM_FINISHED, COPYFILE2_CALLBACK_STREAM_STARTED, + COPYFILE2_PROGRESS_CANCEL, COPYFILE2_PROGRESS_CONTINUE, COPYFILE2_PROGRESS_PAUSE, + COPYFILE2_PROGRESS_QUIET, COPYFILE2_PROGRESS_STOP, FILE_FLAG_FIRST_PIPE_INSTANCE, + FILE_FLAG_OVERLAPPED, FILE_GENERIC_READ, FILE_GENERIC_WRITE, FILE_TYPE_CHAR, + FILE_TYPE_DISK, FILE_TYPE_PIPE, FILE_TYPE_REMOTE, FILE_TYPE_UNKNOWN, OPEN_EXISTING, + PIPE_ACCESS_DUPLEX, PIPE_ACCESS_INBOUND, SYNCHRONIZE, }, System::{ Console::{STD_ERROR_HANDLE, STD_INPUT_HANDLE, STD_OUTPUT_HANDLE}, @@ -53,6 +63,13 @@ mod _winapi { IDLE_PRIORITY_CLASS, INFINITE, NORMAL_PRIORITY_CLASS, PROCESS_DUP_HANDLE, REALTIME_PRIORITY_CLASS, STARTF_USESHOWWINDOW, STARTF_USESTDHANDLES, }, + WindowsProgramming::{ + COPY_FILE_ALLOW_DECRYPTED_DESTINATION, COPY_FILE_COPY_SYMLINK, + COPY_FILE_FAIL_IF_EXISTS, COPY_FILE_NO_BUFFERING, COPY_FILE_NO_OFFLOAD, + COPY_FILE_OPEN_SOURCE_FOR_WRITE, COPY_FILE_REQUEST_COMPRESSED_TRAFFIC, + COPY_FILE_REQUEST_SECURITY_PRIVILEGES, COPY_FILE_RESTARTABLE, + COPY_FILE_RESUME_FROM_PAUSE, + }, }, UI::WindowsAndMessaging::SW_HIDE, }; @@ -119,6 +136,11 @@ mod _winapi { Ok(HANDLE(target)) } + #[pyfunction] + fn GetACP() -> u32 { + unsafe { windows_sys::Win32::Globalization::GetACP() } + } + #[pyfunction] fn GetCurrentProcess() -> HANDLE { unsafe { windows::Win32::System::Threading::GetCurrentProcess() } @@ -137,6 +159,16 @@ mod _winapi { } } + #[pyfunction] + fn GetLastError() -> u32 { + unsafe { 
windows_sys::Win32::Foundation::GetLastError() } + } + + #[pyfunction] + fn GetVersion() -> u32 { + unsafe { windows_sys::Win32::System::SystemInformation::GetVersion() } + } + #[derive(FromArgs)] struct CreateProcessArgs { #[pyarg(positional)] @@ -249,6 +281,21 @@ mod _winapi { )) } + #[pyfunction] + fn OpenProcess( + desired_access: u32, + inherit_handle: bool, + process_id: u32, + ) -> windows_sys::Win32::Foundation::HANDLE { + unsafe { + windows_sys::Win32::System::Threading::OpenProcess( + desired_access, + BOOL::from(inherit_handle), + process_id, + ) + } + } + #[pyfunction] fn NeedCurrentDirectoryForExePath(exe_name: PyStrRef) -> bool { let exe_name = exe_name.as_str().to_wide_with_nul(); @@ -447,4 +494,24 @@ mod _winapi { let (path, _) = path.split_at(length as usize); Ok(String::from_utf16(path).unwrap()) } + + #[pyfunction] + fn OpenMutexW(desired_access: u32, inherit_handle: bool, name: u16) -> PyResult { + let handle = unsafe { + windows_sys::Win32::System::Threading::OpenMutexW( + desired_access, + BOOL::from(inherit_handle), + windows_sys::core::PCWSTR::from(name as _), + ) + }; + // if handle.is_invalid() { + // return Err(errno_err(vm)); + // } + Ok(handle) + } + + #[pyfunction] + fn ReleaseMutex(handle: isize) -> WindowsSysResult { + WindowsSysResult(unsafe { windows_sys::Win32::System::Threading::ReleaseMutex(handle) }) + } } diff --git a/vm/src/stdlib/winreg.rs b/vm/src/stdlib/winreg.rs index b368c43a3e..b0dbbfceec 100644 --- a/vm/src/stdlib/winreg.rs +++ b/vm/src/stdlib/winreg.rs @@ -1,6 +1,6 @@ #![allow(non_snake_case)] -use crate::{builtins::PyModule, PyRef, VirtualMachine}; +use crate::{PyRef, VirtualMachine, builtins::PyModule}; pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { let module = winreg::make_module(vm); @@ -29,10 +29,10 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { mod winreg { use crate::common::lock::{PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard}; use crate::{ - builtins::PyStrRef, 
convert::ToPyException, PyObjectRef, PyPayload, PyRef, PyResult, - TryFromObject, VirtualMachine, + PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, builtins::PyStrRef, + convert::ToPyException, }; - use ::winreg::{enums::RegType, RegKey, RegValue}; + use ::winreg::{RegKey, RegValue, enums::RegType}; use std::mem::ManuallyDrop; use std::{ffi::OsStr, io}; use windows_sys::Win32::Foundation; @@ -47,9 +47,13 @@ mod winreg { // value types #[pyattr] pub use windows_sys::Win32::System::Registry::{ - REG_BINARY, REG_DWORD, REG_DWORD_BIG_ENDIAN, REG_DWORD_LITTLE_ENDIAN, REG_EXPAND_SZ, - REG_FULL_RESOURCE_DESCRIPTOR, REG_LINK, REG_MULTI_SZ, REG_NONE, REG_QWORD, - REG_QWORD_LITTLE_ENDIAN, REG_RESOURCE_LIST, REG_RESOURCE_REQUIREMENTS_LIST, REG_SZ, + REG_BINARY, REG_CREATED_NEW_KEY, REG_DWORD, REG_DWORD_BIG_ENDIAN, REG_DWORD_LITTLE_ENDIAN, + REG_EXPAND_SZ, REG_FULL_RESOURCE_DESCRIPTOR, REG_LINK, REG_MULTI_SZ, REG_NONE, + REG_NOTIFY_CHANGE_ATTRIBUTES, REG_NOTIFY_CHANGE_LAST_SET, REG_NOTIFY_CHANGE_NAME, + REG_NOTIFY_CHANGE_SECURITY, REG_OPENED_EXISTING_KEY, REG_OPTION_BACKUP_RESTORE, + REG_OPTION_CREATE_LINK, REG_OPTION_NON_VOLATILE, REG_OPTION_OPEN_LINK, REG_OPTION_RESERVED, + REG_OPTION_VOLATILE, REG_QWORD, REG_QWORD_LITTLE_ENDIAN, REG_RESOURCE_LIST, + REG_RESOURCE_REQUIREMENTS_LIST, REG_SZ, REG_WHOLE_HIVE_VOLATILE, }; #[pyattr] @@ -98,7 +102,7 @@ mod winreg { #[pymethod(magic)] fn bool(&self) -> bool { - self.key().raw_handle() != 0 + !self.key().raw_handle().is_null() } #[pymethod(magic)] fn enter(zelf: PyRef) -> PyRef { diff --git a/vm/src/suggestion.rs b/vm/src/suggestion.rs index d46630a651..2bc9992d43 100644 --- a/vm/src/suggestion.rs +++ b/vm/src/suggestion.rs @@ -1,10 +1,10 @@ use crate::{ + AsObject, Py, PyObjectRef, VirtualMachine, builtins::{PyStr, PyStrRef}, exceptions::types::PyBaseExceptionRef, sliceable::SliceableSequenceOp, - AsObject, Py, PyObjectRef, VirtualMachine, }; -use rustpython_common::str::levenshtein::{levenshtein_distance, 
MOVE_COST}; +use rustpython_common::str::levenshtein::{MOVE_COST, levenshtein_distance}; use std::iter::ExactSizeIterator; const MAX_CANDIDATE_ITEMS: usize = 750; @@ -52,10 +52,8 @@ pub fn offer_suggestions(exc: &PyBaseExceptionRef, vm: &VirtualMachine) -> Optio calculate_suggestions(vm.dir(Some(obj)).ok()?.borrow_vec().iter(), &name) } else if exc.class().is(vm.ctx.exceptions.name_error) { let name = exc.as_object().get_attr("name", vm).unwrap(); - let mut tb = exc.traceback()?; - for traceback in tb.iter() { - tb = traceback; - } + let tb = exc.traceback()?; + let tb = tb.iter().last().unwrap_or(tb); let varnames = tb.frame.code.clone().co_varnames(vm); if let Some(suggestions) = calculate_suggestions(varnames.iter(), &name) { diff --git a/vm/src/types/slot.rs b/vm/src/types/slot.rs index 75f113977d..de651580d9 100644 --- a/vm/src/types/slot.rs +++ b/vm/src/types/slot.rs @@ -1,5 +1,6 @@ use crate::{ - builtins::{type_::PointerSlot, PyInt, PyStr, PyStrInterned, PyStrRef, PyType, PyTypeRef}, + AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, + builtins::{PyInt, PyStr, PyStrInterned, PyStrRef, PyType, PyTypeRef, type_::PointerSlot}, bytecode::ComparisonOperator, common::hash::PyHash, convert::{ToPyObject, ToPyResult}, @@ -12,7 +13,6 @@ use crate::{ PyNumberSlots, PySequence, PySequenceMethods, }, vm::Context, - AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; use malachite_bigint::BigInt; diff --git a/vm/src/types/structseq.rs b/vm/src/types/structseq.rs index 516e2085af..a0b445ce7d 100644 --- a/vm/src/types/structseq.rs +++ b/vm/src/types/structseq.rs @@ -1,8 +1,8 @@ use crate::{ + AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{PyTuple, PyTupleRef, PyType}, class::{PyClassImpl, StaticType}, vm::Context, - AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; #[pyclass] diff --git a/vm/src/types/zoo.rs 
b/vm/src/types/zoo.rs index 492b584e29..0d39648514 100644 --- a/vm/src/types/zoo.rs +++ b/vm/src/types/zoo.rs @@ -1,4 +1,5 @@ use crate::{ + Py, builtins::{ asyncgenerator, bool_, builtin_func, bytearray, bytes, classmethod, code, complex, coroutine, descriptor, dict, enumerate, filter, float, frame, function, generator, @@ -10,7 +11,6 @@ use crate::{ }, class::StaticType, vm::Context, - Py, }; /// Holder of references to builtin types. diff --git a/vm/src/utils.rs b/vm/src/utils.rs index 2c5ff79d3f..e2bc993686 100644 --- a/vm/src/utils.rs +++ b/vm/src/utils.rs @@ -1,7 +1,8 @@ use crate::{ + PyObjectRef, PyResult, VirtualMachine, builtins::PyStr, convert::{ToPyException, ToPyObject}, - PyObjectRef, PyResult, VirtualMachine, + exceptions::cstring_error, }; pub fn hash_iter<'a, I: IntoIterator>( @@ -17,22 +18,22 @@ impl ToPyObject for std::convert::Infallible { } } -pub trait ToCString { - fn to_cstring(&self, vm: &VirtualMachine) -> PyResult; -} - -impl ToCString for &str { - fn to_cstring(&self, vm: &VirtualMachine) -> PyResult { - std::ffi::CString::new(*self).map_err(|err| err.to_pyexception(vm)) - } -} - -impl ToCString for PyStr { +pub trait ToCString: AsRef { fn to_cstring(&self, vm: &VirtualMachine) -> PyResult { std::ffi::CString::new(self.as_ref()).map_err(|err| err.to_pyexception(vm)) } + fn ensure_no_nul(&self, vm: &VirtualMachine) -> PyResult<()> { + if self.as_ref().as_bytes().contains(&b'\0') { + Err(cstring_error(vm)) + } else { + Ok(()) + } + } } +impl ToCString for &str {} +impl ToCString for PyStr {} + pub(crate) fn collection_repr<'a, I>( class_name: Option<&str>, prefix: &str, diff --git a/vm/src/version.rs b/vm/src/version.rs index 9a75f71142..8c42866a64 100644 --- a/vm/src/version.rs +++ b/vm/src/version.rs @@ -1,12 +1,12 @@ /* Several function to retrieve version information. 
*/ -use chrono::{prelude::DateTime, Local}; +use chrono::{Local, prelude::DateTime}; use std::time::{Duration, UNIX_EPOCH}; -// = 3.12.0alpha +// = 3.13.0alpha pub const MAJOR: usize = 3; -pub const MINOR: usize = 12; +pub const MINOR: usize = 13; pub const MICRO: usize = 0; pub const RELEASELEVEL: &str = "alpha"; pub const RELEASELEVEL_N: usize = 0xA; diff --git a/vm/src/vm/compile.rs b/vm/src/vm/compile.rs index b7c888ab10..a14e986dac 100644 --- a/vm/src/vm/compile.rs +++ b/vm/src/vm/compile.rs @@ -1,9 +1,9 @@ use crate::{ + AsObject, PyObjectRef, PyRef, PyResult, VirtualMachine, builtins::{PyCode, PyDictRef}, compiler::{self, CompileError, CompileOpts}, convert::TryFromObject, scope::Scope, - AsObject, PyObjectRef, PyRef, PyResult, VirtualMachine, }; impl VirtualMachine { diff --git a/vm/src/vm/context.rs b/vm/src/vm/context.rs index f67035d0f3..54605704a5 100644 --- a/vm/src/vm/context.rs +++ b/vm/src/vm/context.rs @@ -1,6 +1,9 @@ use crate::{ + PyResult, VirtualMachine, builtins::{ - bytes, + PyBaseException, PyBytes, PyComplex, PyDict, PyDictRef, PyEllipsis, PyFloat, PyFrozenSet, + PyInt, PyIntRef, PyList, PyListRef, PyNone, PyNotImplemented, PyStr, PyStrInterned, + PyTuple, PyTupleRef, PyType, PyTypeRef, bytes, code::{self, PyCode}, descriptor::{ MemberGetter, MemberKind, MemberSetter, MemberSetterFunc, PyDescriptorOwned, @@ -9,9 +12,6 @@ use crate::{ getset::PyGetSet, object, pystr, type_::PyAttributes, - PyBaseException, PyBytes, PyComplex, PyDict, PyDictRef, PyEllipsis, PyFloat, PyFrozenSet, - PyInt, PyIntRef, PyList, PyListRef, PyNone, PyNotImplemented, PyStr, PyStrInterned, - PyTuple, PyTupleRef, PyType, PyTypeRef, }, class::{PyClassImpl, StaticType}, common::rc::PyRc, @@ -23,7 +23,6 @@ use crate::{ intern::{InternableString, MaybeInternedString, StringPool}, object::{Py, PyObjectPayload, PyObjectRef, PyPayload, PyRef}, types::{PyTypeFlags, PyTypeSlots, TypeZoo}, - PyResult, VirtualMachine, }; use malachite_bigint::BigInt; use num_complex::Complex64; @@ 
-62,7 +61,7 @@ macro_rules! declare_const_name { impl ConstName { unsafe fn new(pool: &StringPool, typ: &PyTypeRef) -> Self { Self { - $($name: pool.intern(stringify!($name), typ.clone()),)* + $($name: unsafe { pool.intern(stringify!($name), typ.clone()) },)* } } } diff --git a/vm/src/vm/interpreter.rs b/vm/src/vm/interpreter.rs index af28ad87ab..a375dbedc1 100644 --- a/vm/src/vm/interpreter.rs +++ b/vm/src/vm/interpreter.rs @@ -1,5 +1,5 @@ -use super::{setting::Settings, thread, Context, VirtualMachine}; -use crate::{stdlib::atexit, vm::PyBaseExceptionRef, PyResult}; +use super::{Context, VirtualMachine, setting::Settings, thread}; +use crate::{PyResult, stdlib::atexit, vm::PyBaseExceptionRef}; use std::sync::atomic::Ordering; /// The general interface for the VM @@ -140,8 +140,8 @@ impl Interpreter { mod tests { use super::*; use crate::{ - builtins::{int, PyStr}, PyObjectRef, + builtins::{PyStr, int}, }; use malachite_bigint::ToBigInt; diff --git a/vm/src/vm/mod.rs b/vm/src/vm/mod.rs index 3d3351055f..05ac245a0a 100644 --- a/vm/src/vm/mod.rs +++ b/vm/src/vm/mod.rs @@ -15,12 +15,13 @@ mod vm_object; mod vm_ops; use crate::{ + AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, builtins::{ + PyBaseExceptionRef, PyDictRef, PyInt, PyList, PyModule, PyStr, PyStrInterned, PyStrRef, + PyTypeRef, code::PyCode, pystr::AsPyStr, tuple::{PyTuple, PyTupleTyped}, - PyBaseExceptionRef, PyDictRef, PyInt, PyList, PyModule, PyStr, PyStrInterned, PyStrRef, - PyTypeRef, }, codecs::CodecsRegistry, common::{hash::HashSecret, lock::PyMutex, rc::PyRc}, @@ -33,12 +34,11 @@ use crate::{ scope::Scope, signal, stdlib, warn::WarningsState, - AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, }; use crossbeam_utils::atomic::AtomicCell; #[cfg(unix)] use nix::{ - sys::signal::{kill, sigaction, SaFlags, SigAction, SigSet, Signal::SIGINT}, + sys::signal::{SaFlags, SigAction, SigSet, Signal::SIGINT, kill, sigaction}, unistd::getpid, }; use std::sync::atomic::AtomicBool; 
@@ -356,7 +356,9 @@ impl VirtualMachine { if self.state.settings.allow_external_library && cfg!(feature = "rustpython-compiler") { if let Err(e) = import::init_importlib_package(self, importlib) { - eprintln!("importlib initialization failed. This is critical for many complicated packages."); + eprintln!( + "importlib initialization failed. This is critical for many complicated packages." + ); self.print_exception(e); } } diff --git a/vm/src/vm/vm_new.rs b/vm/src/vm/vm_new.rs index 55f06e90da..bd42396299 100644 --- a/vm/src/vm/vm_new.rs +++ b/vm/src/vm/vm_new.rs @@ -1,15 +1,15 @@ use crate::{ + AsObject, Py, PyObject, PyObjectRef, PyRef, builtins::{ + PyBaseException, PyBaseExceptionRef, PyDictRef, PyModule, PyStrRef, PyType, PyTypeRef, builtin_func::PyNativeFunction, descriptor::PyMethodDescriptor, tuple::{IntoPyTuple, PyTupleRef}, - PyBaseException, PyBaseExceptionRef, PyDictRef, PyModule, PyStrRef, PyType, PyTypeRef, }, convert::ToPyObject, function::{IntoPyNativeFn, PyMethodFlags}, scope::Scope, vm::VirtualMachine, - AsObject, Py, PyObject, PyObjectRef, PyRef, }; /// Collection of object creation helpers @@ -365,7 +365,15 @@ impl VirtualMachine { let actual_class = obj.class(); let actual_type = &*actual_class.name(); let expected_type = &*class.name(); - let msg = format!("Expected {msg} '{expected_type}' but '{actual_type}' found"); + let msg = format!("Expected {msg} '{expected_type}' but '{actual_type}' found."); + #[cfg(debug_assertions)] + let msg = if class.get_id() == actual_class.get_id() { + let mut msg = msg; + msg += " Did you forget to add `#[pyclass(with(Constructor))]`?"; + msg + } else { + msg + }; self.new_exception_msg(error_type.to_owned(), msg) } diff --git a/vm/src/vm/vm_object.rs b/vm/src/vm/vm_object.rs index b687eea34c..103078272d 100644 --- a/vm/src/vm/vm_object.rs +++ b/vm/src/vm/vm_object.rs @@ -1,6 +1,6 @@ use super::PyMethod; use crate::{ - builtins::{pystr::AsPyStr, PyBaseExceptionRef, PyList, PyStrInterned}, + 
builtins::{PyBaseExceptionRef, PyList, PyStrInterned, pystr::AsPyStr}, function::IntoFuncArgs, identifier, object::{AsObject, PyObject, PyObjectRef, PyResult}, @@ -85,11 +85,7 @@ impl VirtualMachine { obj.is(&self.ctx.none) } pub fn option_if_none(&self, obj: PyObjectRef) -> Option { - if self.is_none(&obj) { - None - } else { - Some(obj) - } + if self.is_none(&obj) { None } else { Some(obj) } } pub fn unwrap_or_none(&self, obj: Option) -> PyObjectRef { obj.unwrap_or_else(|| self.ctx.none()) diff --git a/vm/src/vm/vm_ops.rs b/vm/src/vm/vm_ops.rs index 56c32d2927..09f849a1a1 100644 --- a/vm/src/vm/vm_ops.rs +++ b/vm/src/vm/vm_ops.rs @@ -114,7 +114,7 @@ impl VirtualMachine { Ok(None) } else { Err(e) - } + }; } }; let hint = result @@ -298,8 +298,8 @@ impl VirtualMachine { } if let Some(slot_c) = class_c.slots.as_number.left_ternary_op(op_slot) { - if slot_a.is_some_and(|slot_a| slot_a != slot_c) - && slot_b.is_some_and(|slot_b| slot_b != slot_c) + if slot_a.is_some_and(|slot_a| !std::ptr::fn_addr_eq(slot_a, slot_c)) + && slot_b.is_some_and(|slot_b| !std::ptr::fn_addr_eq(slot_b, slot_c)) { let ret = slot_c(a, b, c, self)?; if !ret.is(&self.ctx.not_implemented) { diff --git a/vm/src/warn.rs b/vm/src/warn.rs index ab316c8559..b2055225a0 100644 --- a/vm/src/warn.rs +++ b/vm/src/warn.rs @@ -1,11 +1,11 @@ use crate::{ + AsObject, Context, Py, PyObjectRef, PyResult, VirtualMachine, builtins::{ PyDict, PyDictRef, PyListRef, PyStr, PyStrInterned, PyStrRef, PyTuple, PyTupleRef, PyTypeRef, }, convert::{IntoObject, TryFromObject}, types::PyComparisonOp, - AsObject, Context, Py, PyObjectRef, PyResult, VirtualMachine, }; pub struct WarningsState { @@ -17,15 +17,16 @@ pub struct WarningsState { impl WarningsState { fn create_filter(ctx: &Context) -> PyListRef { - ctx.new_list(vec![ctx - .new_tuple(vec![ + ctx.new_list(vec![ + ctx.new_tuple(vec![ ctx.new_str("__main__").into(), ctx.types.none_type.as_object().to_owned(), ctx.exceptions.warning.as_object().to_owned(), 
ctx.new_str("ACTION").into(), ctx.new_int(0).into(), ]) - .into()]) + .into(), + ]) } pub fn init_state(ctx: &Context) -> WarningsState { diff --git a/vm/src/windows.rs b/vm/src/windows.rs index f98bb1de63..f4f4dad0b3 100644 --- a/vm/src/windows.rs +++ b/vm/src/windows.rs @@ -1,11 +1,11 @@ use crate::common::fileutils::{ - windows::{get_file_information_by_name, FILE_INFO_BY_NAME_CLASS}, StatStruct, + windows::{FILE_INFO_BY_NAME_CLASS, get_file_information_by_name}, }; use crate::{ + PyObjectRef, PyResult, TryFromObject, VirtualMachine, convert::{ToPyObject, ToPyResult}, stdlib::os::errno_err, - PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; use std::{ffi::OsStr, time::SystemTime}; use windows::Win32::Foundation::HANDLE; @@ -139,7 +139,7 @@ fn file_id(path: &OsStr) -> std::io::Result { use windows_sys::Win32::{ Foundation::HANDLE, Storage::FileSystem::{ - GetFileInformationByHandle, BY_HANDLE_FILE_INFORMATION, FILE_FLAG_BACKUP_SEMANTICS, + BY_HANDLE_FILE_INFORMATION, FILE_FLAG_BACKUP_SEMANTICS, GetFileInformationByHandle, }, }; diff --git a/vm/sre_engine/Cargo.toml b/vm/sre_engine/Cargo.toml index 28f98a3212..504652f3a7 100644 --- a/vm/sre_engine/Cargo.toml +++ b/vm/sre_engine/Cargo.toml @@ -10,7 +10,17 @@ rust-version.workspace = true repository.workspace = true license.workspace = true +[[bench]] +name = "benches" +harness = false + [dependencies] num_enum = { workspace = true } bitflags = { workspace = true } optional = "0.5" + +[dev-dependencies] +criterion = { workspace = true } + +[lints] +workspace = true diff --git a/vm/sre_engine/LICENSE b/vm/sre_engine/LICENSE index 7213274e0f..e2aa2ed952 100644 --- a/vm/sre_engine/LICENSE +++ b/vm/sre_engine/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2020 RustPython Team +Copyright (c) 2025 RustPython Team Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/vm/sre_engine/benches/benches.rs 
b/vm/sre_engine/benches/benches.rs index e31c73b0d0..ee49b036de 100644 --- a/vm/sre_engine/benches/benches.rs +++ b/vm/sre_engine/benches/benches.rs @@ -1,11 +1,9 @@ -#![feature(test)] - -extern crate test; -use test::Bencher; - use rustpython_sre_engine::{Request, State, StrDrive}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; + struct Pattern { + pattern: &'static str, code: &'static [u32], } @@ -25,52 +23,51 @@ impl Pattern { } } -#[bench] -fn benchmarks(b: &mut Bencher) { +fn basic(c: &mut Criterion) { // # test common prefix // pattern p1 = re.compile('Python|Perl') # , 'Perl'), # Alternation // START GENERATED by generate_tests.py - #[rustfmt::skip] let p1 = Pattern { code: &[14, 8, 1, 4, 6, 1, 1, 80, 0, 16, 80, 7, 13, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 15, 11, 9, 16, 101, 16, 114, 16, 108, 15, 2, 0, 1] }; + #[rustfmt::skip] let p1 = Pattern { pattern: "Python|Perl", code: &[14, 8, 1, 4, 6, 1, 1, 80, 0, 16, 80, 7, 13, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 15, 11, 9, 16, 101, 16, 114, 16, 108, 15, 2, 0, 1] }; // END GENERATED // pattern p2 = re.compile('(Python|Perl)') #, 'Perl'), # Grouped alternation // START GENERATED by generate_tests.py - #[rustfmt::skip] let p2 = Pattern { code: &[14, 8, 1, 4, 6, 1, 0, 80, 0, 17, 0, 16, 80, 7, 13, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 15, 11, 9, 16, 101, 16, 114, 16, 108, 15, 2, 0, 17, 1, 1] }; + #[rustfmt::skip] let p2 = Pattern { pattern: "(Python|Perl)", code: &[14, 8, 1, 4, 6, 1, 0, 80, 0, 17, 0, 16, 80, 7, 13, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 15, 11, 9, 16, 101, 16, 114, 16, 108, 15, 2, 0, 17, 1, 1] }; // END GENERATED // pattern p3 = re.compile('Python|Perl|Tcl') #, 'Perl'), # Alternation // START GENERATED by generate_tests.py - #[rustfmt::skip] let p3 = Pattern { code: &[14, 9, 4, 3, 6, 16, 80, 16, 84, 0, 7, 15, 16, 80, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 15, 22, 11, 16, 80, 16, 101, 16, 114, 16, 108, 15, 11, 9, 16, 84, 16, 99, 16, 108, 15, 
2, 0, 1] }; + #[rustfmt::skip] let p3 = Pattern { pattern: "Python|Perl|Tcl", code: &[14, 9, 4, 3, 6, 16, 80, 16, 84, 0, 7, 15, 16, 80, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 15, 22, 11, 16, 80, 16, 101, 16, 114, 16, 108, 15, 11, 9, 16, 84, 16, 99, 16, 108, 15, 2, 0, 1] }; // END GENERATED // pattern p4 = re.compile('(Python|Perl|Tcl)') #, 'Perl'), # Grouped alternation // START GENERATED by generate_tests.py - #[rustfmt::skip] let p4 = Pattern { code: &[14, 9, 4, 3, 6, 16, 80, 16, 84, 0, 17, 0, 7, 15, 16, 80, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 15, 22, 11, 16, 80, 16, 101, 16, 114, 16, 108, 15, 11, 9, 16, 84, 16, 99, 16, 108, 15, 2, 0, 17, 1, 1] }; + #[rustfmt::skip] let p4 = Pattern { pattern: "(Python|Perl|Tcl)", code: &[14, 9, 4, 3, 6, 16, 80, 16, 84, 0, 17, 0, 7, 15, 16, 80, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 15, 22, 11, 16, 80, 16, 101, 16, 114, 16, 108, 15, 11, 9, 16, 84, 16, 99, 16, 108, 15, 2, 0, 17, 1, 1] }; // END GENERATED // pattern p5 = re.compile('(Python)\\1') #, 'PythonPython'), # Backreference // START GENERATED by generate_tests.py - #[rustfmt::skip] let p5 = Pattern { code: &[14, 18, 1, 12, 12, 6, 0, 80, 121, 116, 104, 111, 110, 0, 0, 0, 0, 0, 0, 17, 0, 16, 80, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 17, 1, 11, 0, 1] }; + #[rustfmt::skip] let p5 = Pattern { pattern: "(Python)\\1", code: &[14, 18, 1, 12, 12, 6, 0, 80, 121, 116, 104, 111, 110, 0, 0, 0, 0, 0, 0, 17, 0, 16, 80, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 17, 1, 11, 0, 1] }; // END GENERATED // pattern p6 = re.compile('([0a-z][a-z0-9]*,)+') #, 'a5,b7,c9,'), # Disable the fastmap optimization // START GENERATED by generate_tests.py - #[rustfmt::skip] let p6 = Pattern { code: &[14, 4, 0, 2, 4294967295, 23, 31, 1, 4294967295, 17, 0, 13, 7, 16, 48, 22, 97, 122, 0, 24, 13, 0, 4294967295, 13, 8, 22, 97, 122, 22, 48, 57, 0, 1, 16, 44, 17, 1, 18, 1] }; + #[rustfmt::skip] let p6 = Pattern { pattern: "([0a-z][a-z0-9]*,)+", code: &[14, 4, 0, 2, 4294967295, 23, 31, 
1, 4294967295, 17, 0, 13, 7, 16, 48, 22, 97, 122, 0, 24, 13, 0, 4294967295, 13, 8, 22, 97, 122, 22, 48, 57, 0, 1, 16, 44, 17, 1, 18, 1] }; // END GENERATED // pattern p7 = re.compile('([a-z][a-z0-9]*,)+') #, 'a5,b7,c9,'), # A few sets // START GENERATED by generate_tests.py - #[rustfmt::skip] let p7 = Pattern { code: &[14, 4, 0, 2, 4294967295, 23, 29, 1, 4294967295, 17, 0, 13, 5, 22, 97, 122, 0, 24, 13, 0, 4294967295, 13, 8, 22, 97, 122, 22, 48, 57, 0, 1, 16, 44, 17, 1, 18, 1] }; + #[rustfmt::skip] let p7 = Pattern { pattern: "([a-z][a-z0-9]*,)+", code: &[14, 4, 0, 2, 4294967295, 23, 29, 1, 4294967295, 17, 0, 13, 5, 22, 97, 122, 0, 24, 13, 0, 4294967295, 13, 8, 22, 97, 122, 22, 48, 57, 0, 1, 16, 44, 17, 1, 18, 1] }; // END GENERATED // pattern p8 = re.compile('Python') #, 'Python'), # Simple text literal // START GENERATED by generate_tests.py - #[rustfmt::skip] let p8 = Pattern { code: &[14, 18, 3, 6, 6, 6, 6, 80, 121, 116, 104, 111, 110, 0, 0, 0, 0, 0, 0, 16, 80, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 1] }; + #[rustfmt::skip] let p8 = Pattern { pattern: "Python", code: &[14, 18, 3, 6, 6, 6, 6, 80, 121, 116, 104, 111, 110, 0, 0, 0, 0, 0, 0, 16, 80, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 1] }; // END GENERATED // pattern p9 = re.compile('.*Python') #, 'Python'), # Bad text literal // START GENERATED by generate_tests.py - #[rustfmt::skip] let p9 = Pattern { code: &[14, 4, 0, 6, 4294967295, 24, 5, 0, 4294967295, 2, 1, 16, 80, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 1] }; + #[rustfmt::skip] let p9 = Pattern { pattern: ".*Python", code: &[14, 4, 0, 6, 4294967295, 24, 5, 0, 4294967295, 2, 1, 16, 80, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 1] }; // END GENERATED // pattern p10 = re.compile('.*Python.*') #, 'Python'), # Worse text literal // START GENERATED by generate_tests.py - #[rustfmt::skip] let p10 = Pattern { code: &[14, 4, 0, 6, 4294967295, 24, 5, 0, 4294967295, 2, 1, 16, 80, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 24, 5, 0, 4294967295, 
2, 1, 1] }; + #[rustfmt::skip] let p10 = Pattern { pattern: ".*Python.*", code: &[14, 4, 0, 6, 4294967295, 24, 5, 0, 4294967295, 2, 1, 16, 80, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 24, 5, 0, 4294967295, 2, 1, 1] }; // END GENERATED // pattern p11 = re.compile('.*(Python)') #, 'Python'), # Bad text literal with grouping // START GENERATED by generate_tests.py - #[rustfmt::skip] let p11 = Pattern { code: &[14, 4, 0, 6, 4294967295, 24, 5, 0, 4294967295, 2, 1, 17, 0, 16, 80, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 17, 1, 1] }; + #[rustfmt::skip] let p11 = Pattern { pattern: ".*(Python)", code: &[14, 4, 0, 6, 4294967295, 24, 5, 0, 4294967295, 2, 1, 17, 0, 16, 80, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 17, 1, 1] }; // END GENERATED let tests = [ @@ -87,25 +84,33 @@ fn benchmarks(b: &mut Bencher) { (p11, "Python"), ]; - b.iter(move || { - for (p, s) in &tests { - let (req, mut state) = p.state(s.clone()); - assert!(state.search(req)); - let (req, mut state) = p.state(s.clone()); - assert!(state.pymatch(&req)); - let (mut req, mut state) = p.state(s.clone()); - req.match_all = true; - assert!(state.pymatch(&req)); - let s2 = format!("{}{}{}", " ".repeat(10000), s, " ".repeat(10000)); - let (req, mut state) = p.state_range(s2.as_str(), 0..usize::MAX); - assert!(state.search(req)); - let (req, mut state) = p.state_range(s2.as_str(), 10000..usize::MAX); - assert!(state.pymatch(&req)); - let (req, mut state) = p.state_range(s2.as_str(), 10000..10000 + s.len()); - assert!(state.pymatch(&req)); - let (mut req, mut state) = p.state_range(s2.as_str(), 10000..10000 + s.len()); - req.match_all = true; - assert!(state.pymatch(&req)); - } - }) + let mut group = c.benchmark_group("basic"); + + for (p, s) in tests { + group.bench_with_input(BenchmarkId::new(p.pattern, s), s, |b, s| { + b.iter(|| { + let (req, mut state) = p.state(s); + assert!(state.search(req)); + let (req, mut state) = p.state(s); + assert!(state.pymatch(&req)); + let (mut req, mut state) = 
p.state(s); + req.match_all = true; + assert!(state.pymatch(&req)); + let s2 = format!("{}{}{}", " ".repeat(10000), s, " ".repeat(10000)); + let (req, mut state) = p.state_range(s2.as_str(), 0..usize::MAX); + assert!(state.search(req)); + let (req, mut state) = p.state_range(s2.as_str(), 10000..usize::MAX); + assert!(state.pymatch(&req)); + let (req, mut state) = p.state_range(s2.as_str(), 10000..10000 + s.len()); + assert!(state.pymatch(&req)); + let (mut req, mut state) = p.state_range(s2.as_str(), 10000..10000 + s.len()); + req.match_all = true; + assert!(state.pymatch(&req)); + }); + }); + } } + +criterion_group!(benches, basic); + +criterion_main!(benches); diff --git a/vm/sre_engine/generate_tests.py b/vm/sre_engine/generate_tests.py index 8adf043f29..6621c56813 100644 --- a/vm/sre_engine/generate_tests.py +++ b/vm/sre_engine/generate_tests.py @@ -1,9 +1,6 @@ import os from pathlib import Path import re -import sre_constants -import sre_compile -import sre_parse import json from itertools import chain @@ -11,13 +8,13 @@ sre_engine_magic = int(m.group(1)) del m -assert sre_constants.MAGIC == sre_engine_magic +assert re._constants.MAGIC == sre_engine_magic class CompiledPattern: @classmethod def compile(cls, pattern, flags=0): - p = sre_parse.parse(pattern) - code = sre_compile._code(p, flags) + p = re._parser.parse(pattern) + code = re._compiler._code(p, flags) self = cls() self.pattern = pattern self.code = code @@ -28,12 +25,32 @@ def compile(cls, pattern, flags=0): setattr(CompiledPattern, k, v) +class EscapeRustStr: + hardcoded = { + ord('\r'): '\\r', + ord('\t'): '\\t', + ord('\r'): '\\r', + ord('\n'): '\\n', + ord('\\'): '\\\\', + ord('\''): '\\\'', + ord('\"'): '\\\"', + } + @classmethod + def __class_getitem__(cls, ch): + if (rpl := cls.hardcoded.get(ch)) is not None: + return rpl + if ch in range(0x20, 0x7f): + return ch + return f"\\u{{{ch:x}}}" +def rust_str(s): + return '"' + s.translate(EscapeRustStr) + '"' + # matches `// pattern {varname} = 
re.compile(...)` pattern_pattern = re.compile(r"^((\s*)\/\/\s*pattern\s+(\w+)\s+=\s+(.+?))$(?:.+?END GENERATED)?", re.M | re.S) def replace_compiled(m): line, indent, varname, pattern = m.groups() pattern = eval(pattern, {"re": CompiledPattern}) - pattern = f"Pattern {{ code: &{json.dumps(pattern.code)} }}" + pattern = f"Pattern {{ pattern: {rust_str(pattern.pattern)}, code: &{json.dumps(pattern.code)} }}" return f'''{line} {indent}// START GENERATED by generate_tests.py {indent}#[rustfmt::skip] let {varname} = {pattern}; diff --git a/vm/sre_engine/src/engine.rs b/vm/sre_engine/src/engine.rs index 77f4ccdf36..3425da3371 100644 --- a/vm/sre_engine/src/engine.rs +++ b/vm/sre_engine/src/engine.rs @@ -5,7 +5,7 @@ use crate::string::{ is_uni_word, is_word, lower_ascii, lower_locate, lower_unicode, upper_locate, upper_unicode, }; -use super::{SreAtCode, SreCatCode, SreInfo, SreOpcode, StrDrive, StringCursor, MAXREPEAT}; +use super::{MAXREPEAT, SreAtCode, SreCatCode, SreInfo, SreOpcode, StrDrive, StringCursor}; use optional::Optioned; use std::{convert::TryFrom, ptr::null}; diff --git a/vm/sre_engine/src/lib.rs b/vm/sre_engine/src/lib.rs index fd9f367dc6..08c21de9df 100644 --- a/vm/sre_engine/src/lib.rs +++ b/vm/sre_engine/src/lib.rs @@ -2,7 +2,7 @@ pub mod constants; pub mod engine; pub mod string; -pub use constants::{SreAtCode, SreCatCode, SreFlag, SreInfo, SreOpcode, SRE_MAGIC}; +pub use constants::{SRE_MAGIC, SreAtCode, SreCatCode, SreFlag, SreInfo, SreOpcode}; pub use engine::{Request, SearchIter, State}; pub use string::{StrDrive, StringCursor}; diff --git a/vm/sre_engine/src/string.rs b/vm/sre_engine/src/string.rs index ca8b3179dc..77e0f3e772 100644 --- a/vm/sre_engine/src/string.rs +++ b/vm/sre_engine/src/string.rs @@ -157,8 +157,8 @@ impl StrDrive for &str { #[inline] unsafe fn next_code_point(ptr: &mut *const u8) -> u32 { // Decode UTF-8 - let x = **ptr; - *ptr = ptr.offset(1); + let x = unsafe { **ptr }; + *ptr = unsafe { ptr.offset(1) }; if x < 128 { return x 
as u32; @@ -170,26 +170,26 @@ unsafe fn next_code_point(ptr: &mut *const u8) -> u32 { let init = utf8_first_byte(x, 2); // SAFETY: `bytes` produces an UTF-8-like string, // so the iterator must produce a value here. - let y = **ptr; - *ptr = ptr.offset(1); + let y = unsafe { **ptr }; + *ptr = unsafe { ptr.offset(1) }; let mut ch = utf8_acc_cont_byte(init, y); if x >= 0xE0 { // [[x y z] w] case // 5th bit in 0xE0 .. 0xEF is always clear, so `init` is still valid // SAFETY: `bytes` produces an UTF-8-like string, // so the iterator must produce a value here. - let z = **ptr; - *ptr = ptr.offset(1); + let z = unsafe { **ptr }; + *ptr = unsafe { ptr.offset(1) }; let y_z = utf8_acc_cont_byte((y & CONT_MASK) as u32, z); - ch = init << 12 | y_z; + ch = (init << 12) | y_z; if x >= 0xF0 { // [x y z w] case // use only the lower 3 bits of `init` // SAFETY: `bytes` produces an UTF-8-like string, // so the iterator must produce a value here. - let w = **ptr; - *ptr = ptr.offset(1); - ch = (init & 7) << 18 | utf8_acc_cont_byte(y_z, w); + let w = unsafe { **ptr }; + *ptr = unsafe { ptr.offset(1) }; + ch = ((init & 7) << 18) | utf8_acc_cont_byte(y_z, w); } } @@ -205,8 +205,8 @@ unsafe fn next_code_point(ptr: &mut *const u8) -> u32 { #[inline] unsafe fn next_code_point_reverse(ptr: &mut *const u8) -> u32 { // Decode UTF-8 - *ptr = ptr.offset(-1); - let w = match **ptr { + *ptr = unsafe { ptr.offset(-1) }; + let w = match unsafe { **ptr } { next_byte if next_byte < 128 => return next_byte as u32, back_byte => back_byte, }; @@ -216,20 +216,20 @@ unsafe fn next_code_point_reverse(ptr: &mut *const u8) -> u32 { let mut ch; // SAFETY: `bytes` produces an UTF-8-like string, // so the iterator must produce a value here. - *ptr = ptr.offset(-1); - let z = **ptr; + *ptr = unsafe { ptr.offset(-1) }; + let z = unsafe { **ptr }; ch = utf8_first_byte(z, 2); if utf8_is_cont_byte(z) { // SAFETY: `bytes` produces an UTF-8-like string, // so the iterator must produce a value here. 
- *ptr = ptr.offset(-1); - let y = **ptr; + *ptr = unsafe { ptr.offset(-1) }; + let y = unsafe { **ptr }; ch = utf8_first_byte(y, 3); if utf8_is_cont_byte(y) { // SAFETY: `bytes` produces an UTF-8-like string, // so the iterator must produce a value here. - *ptr = ptr.offset(-1); - let x = **ptr; + *ptr = unsafe { ptr.offset(-1) }; + let x = unsafe { **ptr }; ch = utf8_first_byte(x, 4); ch = utf8_acc_cont_byte(ch, y); } diff --git a/vm/sre_engine/tests/tests.rs b/vm/sre_engine/tests/tests.rs index 53494c5e3d..0946fd64ca 100644 --- a/vm/sre_engine/tests/tests.rs +++ b/vm/sre_engine/tests/tests.rs @@ -1,6 +1,7 @@ use rustpython_sre_engine::{Request, State, StrDrive}; struct Pattern { + pattern: &'static str, code: &'static [u32], } @@ -16,7 +17,7 @@ impl Pattern { fn test_2427() { // pattern lookbehind = re.compile(r'(?x)++x') // START GENERATED by generate_tests.py - #[rustfmt::skip] let p = Pattern { code: &[14, 4, 0, 2, 4294967295, 28, 8, 1, 4294967295, 27, 4, 16, 120, 1, 1, 16, 120, 1] }; + #[rustfmt::skip] let p = Pattern { pattern: "(?>x)++x", code: &[14, 4, 0, 2, 4294967295, 28, 8, 1, 4294967295, 27, 4, 16, 120, 1, 1, 16, 120, 1] }; // END GENERATED let (req, mut state) = p.state("xxx"); assert!(!state.pymatch(&req)); @@ -156,7 +157,7 @@ fn test_possessive_atomic_group() { fn test_bug_20998() { // pattern p = re.compile('[a-c]+', re.I) // START GENERATED by generate_tests.py - #[rustfmt::skip] let p = Pattern { code: &[14, 4, 0, 1, 4294967295, 24, 10, 1, 4294967295, 39, 5, 22, 97, 99, 0, 1, 1] }; + #[rustfmt::skip] let p = Pattern { pattern: "[a-c]+", code: &[14, 4, 0, 1, 4294967295, 24, 10, 1, 4294967295, 39, 5, 22, 97, 99, 0, 1, 1] }; // END GENERATED let (mut req, mut state) = p.state("ABC"); req.match_all = true; @@ -168,7 +169,7 @@ fn test_bug_20998() { fn test_bigcharset() { // pattern p = re.compile('[a-z]*', re.I) // START GENERATED by generate_tests.py - #[rustfmt::skip] let p = Pattern { code: &[14, 4, 0, 0, 4294967295, 24, 97, 0, 4294967295, 39, 92, 
10, 3, 33685760, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 0, 0, 0, 134217726, 0, 0, 0, 0, 0, 131072, 0, 2147483648, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1] }; + #[rustfmt::skip] let p = Pattern { pattern: "[a-z]*", code: &[14, 4, 0, 0, 4294967295, 24, 97, 0, 4294967295, 39, 92, 10, 3, 33685760, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 0, 0, 0, 134217726, 0, 0, 0, 0, 0, 131072, 0, 2147483648, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1] }; // END GENERATED let (req, mut state) = p.state("x "); assert!(state.pymatch(&req)); @@ -178,4 +179,7 @@ fn test_bigcharset() { #[test] fn test_search_nonascii() { // pattern p = re.compile('\xe0+') + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p = Pattern { pattern: "\u{e0}+", code: &[14, 4, 0, 1, 4294967295, 24, 6, 1, 4294967295, 16, 224, 1, 1] }; + 
// END GENERATED } diff --git a/wasm/lib/.cargo/config.toml b/wasm/lib/.cargo/config.toml new file mode 100644 index 0000000000..ce1e7c694a --- /dev/null +++ b/wasm/lib/.cargo/config.toml @@ -0,0 +1,5 @@ +[build] +target = "wasm32-unknown-unknown" + +[target.wasm32-unknown-unknown] +rustflags = ["--cfg=getrandom_backend=\"wasm_js\""] diff --git a/wasm/lib/Cargo.toml b/wasm/lib/Cargo.toml index 1e5c37f4ef..4703cb9f4a 100644 --- a/wasm/lib/Cargo.toml +++ b/wasm/lib/Cargo.toml @@ -28,6 +28,9 @@ rustpython-parser = { workspace = true } serde = { workspace = true } wasm-bindgen = { workspace = true } +# remove once getrandom 0.2 is no longer otherwise in the dependency tree +getrandom = { version = "0.2", features = ["js"] } + console_error_panic_hook = "0.1" js-sys = "0.3" serde-wasm-bindgen = "0.3.1" @@ -47,4 +50,4 @@ web-sys = { version = "0.3", features = [ wasm-opt = false#["-O1"] [lints] -workspace = true \ No newline at end of file +workspace = true diff --git a/wasm/lib/src/browser_module.rs b/wasm/lib/src/browser_module.rs index 86a3ea3ab3..f8d1b2ebc3 100644 --- a/wasm/lib/src/browser_module.rs +++ b/wasm/lib/src/browser_module.rs @@ -7,14 +7,14 @@ mod _browser { use crate::{convert, js_module::PyPromise, vm_class::weak_vm, wasm_builtins::window}; use js_sys::Promise; use rustpython_vm::{ + PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{PyDictRef, PyStrRef}, class::PyClassImpl, convert::ToPyObject, function::{ArgCallable, OptionalArg}, import::import_source, - PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; - use wasm_bindgen::{prelude::*, JsCast}; + use wasm_bindgen::{JsCast, prelude::*}; use wasm_bindgen_futures::JsFuture; enum FetchResponseFormat { diff --git a/wasm/lib/src/convert.rs b/wasm/lib/src/convert.rs index 0bf1d21c95..d4fb068e6c 100644 --- a/wasm/lib/src/convert.rs +++ b/wasm/lib/src/convert.rs @@ -1,17 +1,18 @@ #![allow(clippy::empty_docs)] // TODO: remove it later. 
false positive by wasm-bindgen generated code use crate::js_module; -use crate::vm_class::{stored_vm_from_wasm, WASMVirtualMachine}; +use crate::vm_class::{WASMVirtualMachine, stored_vm_from_wasm}; use js_sys::{Array, ArrayBuffer, Object, Promise, Reflect, SyntaxError, Uint8Array}; -use rustpython_parser::{lexer::LexicalErrorType, ParseErrorType}; +use rustpython_parser::{ParseErrorType, lexer::LexicalErrorType}; use rustpython_vm::{ + AsObject, PyObjectRef, PyPayload, PyResult, TryFromBorrowedObject, VirtualMachine, builtins::PyBaseExceptionRef, compiler::{CompileError, CompileErrorType}, exceptions, function::{ArgBytesLike, FuncArgs}, - py_serde, AsObject, PyObjectRef, PyPayload, PyResult, TryFromBorrowedObject, VirtualMachine, + py_serde, }; -use wasm_bindgen::{closure::Closure, prelude::*, JsCast}; +use wasm_bindgen::{JsCast, closure::Closure, prelude::*}; #[wasm_bindgen(inline_js = r" export class PyError extends Error { diff --git a/wasm/lib/src/js_module.rs b/wasm/lib/src/js_module.rs index f0c5378c35..e25499df4d 100644 --- a/wasm/lib/src/js_module.rs +++ b/wasm/lib/src/js_module.rs @@ -5,21 +5,21 @@ use rustpython_vm::VirtualMachine; mod _js { use crate::{ convert, - vm_class::{stored_vm_from_wasm, WASMVirtualMachine}, + vm_class::{WASMVirtualMachine, stored_vm_from_wasm}, weak_vm, }; use js_sys::{Array, Object, Promise, Reflect}; use rustpython_vm::{ + Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, builtins::{PyBaseExceptionRef, PyFloat, PyStrRef, PyType, PyTypeRef}, convert::{IntoObject, ToPyObject}, function::{ArgCallable, OptionalArg, OptionalOption, PosArgs}, protocol::PyIterReturn, types::{IterNext, Representable, SelfIter}, - Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, }; use std::{cell, fmt, future}; - use wasm_bindgen::{closure::Closure, prelude::*, JsCast}; - use wasm_bindgen_futures::{future_to_promise, JsFuture}; + use wasm_bindgen::{JsCast, closure::Closure, prelude::*}; + use 
wasm_bindgen_futures::{JsFuture, future_to_promise}; #[wasm_bindgen(inline_js = " export function has_prop(target, prop) { return prop in Object(target); } diff --git a/wasm/lib/src/vm_class.rs b/wasm/lib/src/vm_class.rs index 0c62bb32e7..c04877f7e3 100644 --- a/wasm/lib/src/vm_class.rs +++ b/wasm/lib/src/vm_class.rs @@ -5,10 +5,10 @@ use crate::{ }; use js_sys::{Object, TypeError}; use rustpython_vm::{ + Interpreter, PyObjectRef, PyPayload, PyRef, PyResult, Settings, VirtualMachine, builtins::{PyModule, PyWeak}, compiler::Mode, scope::Scope, - Interpreter, PyObjectRef, PyPayload, PyRef, PyResult, Settings, VirtualMachine, }; use std::{ cell::RefCell, @@ -233,7 +233,7 @@ impl WASMVirtualMachine { #[wasm_bindgen(js_name = addToScope)] pub fn add_to_scope(&self, name: String, value: JsValue) -> Result<(), JsValue> { - self.with_vm(move |vm, StoredVirtualMachine { ref scope, .. }| { + self.with_vm(move |vm, StoredVirtualMachine { scope, .. }| { let value = convert::js_to_py(vm, value); scope.globals.set_item(&name, value, vm).into_js(vm) })? @@ -335,7 +335,7 @@ impl WASMVirtualMachine { mode: Mode, source_path: Option, ) -> Result { - self.with_vm(|vm, StoredVirtualMachine { ref scope, .. }| { + self.with_vm(|vm, StoredVirtualMachine { scope, .. }| { let source_path = source_path.unwrap_or_else(|| "".to_owned()); let code = vm.compile(source, mode, source_path); let code = code.map_err(convert::syntax_err)?; diff --git a/wasm/lib/src/wasm_builtins.rs b/wasm/lib/src/wasm_builtins.rs index 59d7880af4..79423fc250 100644 --- a/wasm/lib/src/wasm_builtins.rs +++ b/wasm/lib/src/wasm_builtins.rs @@ -4,7 +4,7 @@ //! desktop. //! Implements functions listed here: https://docs.python.org/3/library/builtins.html. 
-use rustpython_vm::{builtins::PyStrRef, PyObjectRef, PyRef, PyResult, VirtualMachine}; +use rustpython_vm::{PyObjectRef, PyRef, PyResult, VirtualMachine, builtins::PyStrRef}; use web_sys::{self, console}; pub(crate) fn window() -> web_sys::Window { diff --git a/wasm/wasm-unknown-test/.cargo/config.toml b/wasm/wasm-unknown-test/.cargo/config.toml new file mode 100644 index 0000000000..f86ad96761 --- /dev/null +++ b/wasm/wasm-unknown-test/.cargo/config.toml @@ -0,0 +1,5 @@ +[build] +target = "wasm32-unknown-unknown" + +[target.wasm32-unknown-unknown] +rustflags = ["--cfg=getrandom_backend=\"custom\""] diff --git a/wasm/wasm-unknown-test/Cargo.toml b/wasm/wasm-unknown-test/Cargo.toml index f5e0b55786..ca8b15cfc5 100644 --- a/wasm/wasm-unknown-test/Cargo.toml +++ b/wasm/wasm-unknown-test/Cargo.toml @@ -8,6 +8,7 @@ crate-type = ["cdylib"] [dependencies] getrandom = { version = "0.2.12", features = ["custom"] } +getrandom_03 = { package = "getrandom", version = "0.3" } rustpython-vm = { path = "../../vm", default-features = false, features = ["compiler"] } [workspace] diff --git a/wasm/wasm-unknown-test/src/lib.rs b/wasm/wasm-unknown-test/src/lib.rs index fd043aea3a..cfdc445574 100644 --- a/wasm/wasm-unknown-test/src/lib.rs +++ b/wasm/wasm-unknown-test/src/lib.rs @@ -14,3 +14,11 @@ fn getrandom_always_fail(_buf: &mut [u8]) -> Result<(), getrandom::Error> { } getrandom::register_custom_getrandom!(getrandom_always_fail); + +#[unsafe(no_mangle)] +unsafe extern "Rust" fn __getrandom_v03_custom( + _dest: *mut u8, + _len: usize, +) -> Result<(), getrandom_03::Error> { + Err(getrandom_03::Error::UNSUPPORTED) +} diff --git a/whats_left.py b/whats_left.py index 4f087f89af..30b1de088e 100755 --- a/whats_left.py +++ b/whats_left.py @@ -35,8 +35,8 @@ implementation = platform.python_implementation() if implementation != "CPython": sys.exit(f"whats_left.py must be run under CPython, got {implementation} instead") -if sys.version_info[:2] < (3, 12): - sys.exit(f"whats_left.py must be 
run under CPython 3.12 or newer, got {implementation} {sys.version} instead") +if sys.version_info[:2] < (3, 13): + sys.exit(f"whats_left.py must be run under CPython 3.13 or newer, got {implementation} {sys.version} instead") def parse_args(): parser = argparse.ArgumentParser(description="Process some integers.")