diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 000000000..253a13aca --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,12 @@ +# These are supported funding model platforms + +github: ["methane"] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] +patreon: # Replace with a single Patreon username +open_collective: # Replace with a single Open Collective username +ko_fi: # Replace with a single Ko-fi username +tidelift: "pypi/PyMySQL" # Replace with a single Tidelift platform-name/package-name e.g., npm/babel +community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry +liberapay: # Replace with a single Liberapay username +issuehunt: # Replace with a single IssueHunt username +otechie: # Replace with a single Otechie username +custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md deleted file mode 100644 index 3e0fbe826..000000000 --- a/.github/ISSUE_TEMPLATE.md +++ /dev/null @@ -1,11 +0,0 @@ -This project is maintained one busy person with a frail wife and an infant daughter. -My time and energy is a very limited resource. I'm not a teacher or free tech support. -Don't ask a question here. Don't file an issue until you believe it's a not a problem with your code. -Search for friendly volunteers who can teach you or review your code on ML or Q&A sites. - -See also: https://medium.com/@methane/why-you-must-not-ask-questions-on-github-issues-51d741d83fde - - -If you're sure it's PyMySQL's issue, report the complete steps to reproduce, from creating database. - -I don't have time to investigate your issue from an incomplete code snippet. diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 000000000..f2bd4d300 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,39 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: '' +assignees: '' + +--- + +**Describe the bug** +A clear and concise description of what the bug is. + +**To Reproduce** +Complete steps to reproduce the behavior: + +Schema: + +``` +CREATE DATABASE ... +CREATE TABLE ... +``` + +Code: + +```py +import pymysql +con = pymysql.connect(...) +``` + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Environment** + - OS: [e.g. Windows, Linux] + - Server and version: [e.g. MySQL 8.0.19, MariaDB] + - PyMySQL version: + +**Additional context** +Add any other context about the problem here. diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml new file mode 100644 index 000000000..df49979ea --- /dev/null +++ b/.github/workflows/codeql-analysis.yml @@ -0,0 +1,62 @@ +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. +# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. +# +# ******** NOTE ******** +# We have attempted to detect the languages in your repository. Please check +# the `language` matrix defined below to confirm you have the correct set of +# supported CodeQL languages. +# +name: "CodeQL" + +on: + push: + branches: [ main ] + pull_request: + # The branches below must be a subset of the branches above + branches: [ main ] + schedule: + - cron: '34 7 * * 2' + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + + strategy: + fail-fast: false + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v3 + with: + languages: "python" + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + # queries: ./path/to/local/query, your-org/your-repo/queries@main + + # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). + # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@v3 + + # â„šī¸ Command-line programs to run using the OS shell. + # 📚 https://git.io/JvXDl + + # âœī¸ If the Autobuild fails above, remove it and uncomment the following three lines + # and modify them (or add more) to build your code if your project + # uses a compiled language + + #- run: | + # make bootstrap + # make release + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v3 diff --git a/.github/workflows/codesee-arch-diagram.yml b/.github/workflows/codesee-arch-diagram.yml new file mode 100644 index 000000000..806d41d12 --- /dev/null +++ b/.github/workflows/codesee-arch-diagram.yml @@ -0,0 +1,23 @@ +# This workflow was added by CodeSee. Learn more at https://codesee.io/ +# This is v2.0 of this workflow file +on: + push: + branches: + - main + pull_request_target: + types: [opened, synchronize, reopened] + +name: CodeSee + +permissions: read-all + +jobs: + codesee: + runs-on: ubuntu-latest + continue-on-error: true + name: Analyze the repo with CodeSee + steps: + - uses: Codesee-io/codesee-action@v2 + with: + codesee-token: ${{ secrets.CODESEE_ARCH_DIAG_API_TOKEN }} + codesee-url: https://app.codesee.io diff --git a/.github/workflows/django.yaml b/.github/workflows/django.yaml new file mode 100644 index 000000000..5c4609543 --- /dev/null +++ b/.github/workflows/django.yaml @@ -0,0 +1,66 @@ +name: Django test + +on: + push: + # branches: ["main"] + # pull_request: + +jobs: + django-test: + name: "Run Django LTS test suite" + runs-on: ubuntu-latest + # There are some known difference between MySQLdb and PyMySQL. + continue-on-error: true + env: + PIP_NO_PYTHON_VERSION_WARNING: 1 + PIP_DISABLE_PIP_VERSION_CHECK: 1 + # DJANGO_VERSION: "3.2.19" + strategy: + fail-fast: false + matrix: + include: + # Django 3.2.9+ supports Python 3.10 + # https://docs.djangoproject.com/ja/3.2/releases/3.2/ + - django: "3.2.19" + python: "3.10" + + - django: "4.2.1" + python: "3.11" + + steps: + - name: Start MySQL + run: | + sudo systemctl start mysql.service + mysql_tzinfo_to_sql /usr/share/zoneinfo | mysql -uroot -proot mysql + mysql -uroot -proot -e "set global innodb_flush_log_at_trx_commit=0;" + mysql -uroot -proot -e "CREATE USER 'scott'@'%' IDENTIFIED BY 'tiger'; GRANT ALL ON *.* TO scott;" + mysql -uroot -proot -e "CREATE DATABASE django_default; CREATE DATABASE django_other;" + + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + + - name: Install mysqlclient + run: | + #pip install mysqlclient # Use stable version + pip install .[rsa] + + - name: Setup Django + run: | + sudo apt-get install libmemcached-dev + wget https://github.com/django/django/archive/${{ matrix.django }}.tar.gz + tar xf ${{ matrix.django }}.tar.gz + cp ci/test_mysql.py django-${{ matrix.django }}/tests/ + cd django-${{ matrix.django }} + pip install . -r tests/requirements/py3.txt + + - name: Run Django test + run: | + cd django-${{ matrix.django }}/tests/ + # test_runner does not using our test_mysql.py + # We can't run whole django test suite for now. + # Run olly backends test + DJANGO_SETTINGS_MODULE=test_mysql python runtests.py backends diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 000000000..269211c25 --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,25 @@ +name: Lint + +on: + push: + branches: ["main"] + paths: + - '**.py' + pull_request: + paths: + - '**.py' + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: checkout + uses: actions/checkout@v4 + + - name: lint + uses: chartboost/ruff-action@v1 + + - name: check format + uses: chartboost/ruff-action@v1 + with: + args: "format --diff" diff --git a/.github/workflows/lock.yml b/.github/workflows/lock.yml new file mode 100644 index 000000000..21449e3b8 --- /dev/null +++ b/.github/workflows/lock.yml @@ -0,0 +1,17 @@ +name: 'Lock Threads' + +on: + schedule: + - cron: '30 9 * * 1' + +permissions: + issues: write + pull-requests: write + +jobs: + lock-threads: + if: github.repository == 'PyMySQL/PyMySQL' + runs-on: ubuntu-latest + steps: + - uses: dessant/lock-threads@v5 + diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 000000000..6d59d8c4b --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,109 @@ +name: Test + +on: + push: + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} + cancel-in-progress: true + +env: + FORCE_COLOR: 1 + +jobs: + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - db: "mariadb:10.4" + py: "3.8" + + - db: "mariadb:10.5" + py: "3.7" + + - db: "mariadb:10.6" + py: "3.11" + + - db: "mariadb:10.6" + py: "3.12" + + - db: "mariadb:lts" + py: "3.9" + + - db: "mysql:5.7" + py: "pypy-3.8" + + - db: "mysql:8.0" + py: "3.9" + mysql_auth: true + + - db: "mysql:8.0" + py: "3.10" + + services: + mysql: + image: "${{ matrix.db }}" + ports: + - 3306:3306 + env: + MYSQL_ALLOW_EMPTY_PASSWORD: yes + MARIADB_ALLOW_EMPTY_ROOT_PASSWORD: yes + options: "--name=mysqld" + volumes: + - /run/mysqld:/run/mysqld + + steps: + - uses: actions/checkout@v4 + + - name: Workaround MySQL container permissions + if: startsWith(matrix.db, 'mysql') + run: | + sudo chown 999:999 /run/mysqld + /usr/bin/docker ps --all --filter status=exited --no-trunc --format "{{.ID}}" | xargs -r /usr/bin/docker start + + - name: Set up Python ${{ matrix.py }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.py }} + allow-prereleases: true + cache: 'pip' + cache-dependency-path: 'requirements-dev.txt' + + - name: Install dependency + run: | + pip install --upgrade -r requirements-dev.txt + + - name: Set up MySQL + run: | + while : + do + sleep 1 + mysql -h127.0.0.1 -uroot -e 'select version()' && break + done + mysql -h127.0.0.1 -uroot -e "SET GLOBAL local_infile=on" + mysql -h127.0.0.1 -uroot --comments < ci/docker-entrypoint-initdb.d/init.sql + mysql -h127.0.0.1 -uroot --comments < ci/docker-entrypoint-initdb.d/mysql.sql + mysql -h127.0.0.1 -uroot --comments < ci/docker-entrypoint-initdb.d/mariadb.sql + cp ci/docker.json pymysql/tests/databases.json + + - name: Run test + run: | + pytest -v --cov --cov-config .coveragerc pymysql + pytest -v --cov-append --cov-config .coveragerc --doctest-modules pymysql/converters.py + + - name: Run MySQL8 auth test + if: ${{ matrix.mysql_auth }} + run: | + docker cp mysqld:/var/lib/mysql/public_key.pem "${HOME}" + docker cp mysqld:/var/lib/mysql/ca.pem "${HOME}" + docker cp mysqld:/var/lib/mysql/server-cert.pem "${HOME}" + docker cp mysqld:/var/lib/mysql/client-key.pem "${HOME}" + docker cp mysqld:/var/lib/mysql/client-cert.pem "${HOME}" + pytest -v --cov-append --cov-config .coveragerc tests/test_auth.py; + + - name: Upload coverage reports to Codecov + if: github.repository == 'PyMySQL/PyMySQL' + uses: codecov/codecov-action@v4 diff --git a/.gitignore b/.gitignore index 98f4d45c8..09a5654fb 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ /pymysql/tests/databases.json __pycache__ Pipfile.lock +pdm.lock diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 000000000..59fdb65df --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,17 @@ +# .readthedocs.yaml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3.12" + +python: + install: + - requirements: docs/requirements.txt + +# Build documentation in the docs/ directory with Sphinx +sphinx: + configuration: docs/source/conf.py diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index f4a7cc74e..000000000 --- a/.travis.yml +++ /dev/null @@ -1,67 +0,0 @@ -# vim: sw=2 ts=2 sts=2 expandtab - -sudo: required -language: python -services: - - docker - -cache: pip - -matrix: - include: - - env: - - DB=mariadb:5.5 - python: "3.5" - - env: - - DB=mariadb:10.0 - python: "3.6" - - env: - - DB=mariadb:10.1 - python: "pypy" - - env: - - DB=mariadb:10.2 - python: "2.7" - - env: - - DB=mariadb:10.3 - python: "3.7-dev" - - env: - - DB=mysql:5.5 - python: "3.5" - - env: - - DB=mysql:5.6 - python: "3.6" - - env: - - DB=mysql:5.7 - python: "3.4" - - env: - - DB=mysql:8.0 - - TEST_AUTH=yes - python: "3.7-dev" - - env: - - DB=mysql:8.0 - - TEST_AUTH=yes - python: "2.7" - -# different py version from 5.6 and 5.7 as cache seems to be based on py version -# http://dev.mysql.com/downloads/mysql/5.7.html has latest development release version -# really only need libaio1 for DB builds however libaio-dev is whitelisted for container builds and liaio1 isn't -install: - - pip install -U coveralls unittest2 coverage cryptography pytest pytest-cov - -before_script: - - ./.travis/initializedb.sh - - python -VV - - rm -f ~/.my.cnf # set in .travis.initialize.db.sh for the above commands - we should be using database.json however - - export COVERALLS_PARALLEL=true - -script: - - coverage run ./runtests.py - - if [ "${TEST_AUTH}" = "yes" ]; - then pytest -v --cov-config .coveragerc tests; - fi - - if [ ! -z "${DB}" ]; - then docker logs mysqld; - fi - -after_success: - - coveralls diff --git a/.travis/database.json b/.travis/database.json deleted file mode 100644 index ab1f60a3a..000000000 --- a/.travis/database.json +++ /dev/null @@ -1,4 +0,0 @@ -[ - {"host": "localhost", "unix_socket": "/var/run/mysqld/mysqld.sock", "user": "root", "passwd": "", "db": "test1", "use_unicode": true, "local_infile": true}, - {"host": "127.0.0.1", "port": 3306, "user": "test2", "password": "some password", "db": "test2" } -] diff --git a/.travis/docker.json b/.travis/docker.json deleted file mode 100644 index b851fb6da..000000000 --- a/.travis/docker.json +++ /dev/null @@ -1,4 +0,0 @@ -[ - {"host": "127.0.0.1", "port": 3306, "user": "root", "passwd": "", "db": "test1", "use_unicode": true, "local_infile": true}, - {"host": "127.0.0.1", "port": 3306, "user": "test2", "password": "some password", "db": "test2" } -] diff --git a/.travis/initializedb.sh b/.travis/initializedb.sh deleted file mode 100755 index d9897e49c..000000000 --- a/.travis/initializedb.sh +++ /dev/null @@ -1,72 +0,0 @@ -#!/bin/bash - -#debug -set -x -#verbose -set -v - -if [ ! -z "${DB}" ]; then - # disable existing database server in case of accidential connection - sudo service mysql stop - - docker pull ${DB} - docker run -it --name=mysqld -d -e MYSQL_ALLOW_EMPTY_PASSWORD=yes -p 3306:3306 ${DB} - sleep 10 - - mysql() { - docker exec mysqld mysql "${@}" - } - while : - do - sleep 5 - mysql -e 'select version()' - if [ $? = 0 ]; then - break - fi - echo "server logs" - docker logs --tail 5 mysqld - done - - mysql -e 'select VERSION()' - - if [ $DB == 'mysql:8.0' ]; then - WITH_PLUGIN='with mysql_native_password' - mysql -e 'SET GLOBAL local_infile=on' - docker cp mysqld:/var/lib/mysql/public_key.pem "${HOME}" - docker cp mysqld:/var/lib/mysql/ca.pem "${HOME}" - docker cp mysqld:/var/lib/mysql/server-cert.pem "${HOME}" - docker cp mysqld:/var/lib/mysql/client-key.pem "${HOME}" - docker cp mysqld:/var/lib/mysql/client-cert.pem "${HOME}" - - # Test user for auth test - mysql -e ' - CREATE USER - user_sha256 IDENTIFIED WITH "sha256_password" BY "pass_sha256", - nopass_sha256 IDENTIFIED WITH "sha256_password", - user_caching_sha2 IDENTIFIED WITH "caching_sha2_password" BY "pass_caching_sha2", - nopass_caching_sha2 IDENTIFIED WITH "caching_sha2_password" - PASSWORD EXPIRE NEVER;' - mysql -e 'GRANT RELOAD ON *.* TO user_caching_sha2;' - else - WITH_PLUGIN='' - fi - - mysql -uroot -e 'create database test1 DEFAULT CHARACTER SET utf8mb4' - mysql -uroot -e 'create database test2 DEFAULT CHARACTER SET utf8mb4' - - mysql -u root -e "create user test2 identified ${WITH_PLUGIN} by 'some password'; grant all on test2.* to test2;" - mysql -u root -e "create user test2@localhost identified ${WITH_PLUGIN} by 'some password'; grant all on test2.* to test2@localhost;" - - cp .travis/docker.json pymysql/tests/databases.json -else - cat ~/.my.cnf - - mysql -e 'select VERSION()' - mysql -e 'create database test1 DEFAULT CHARACTER SET utf8 DEFAULT COLLATE utf8_general_ci;' - mysql -e 'create database test2 DEFAULT CHARACTER SET utf8 DEFAULT COLLATE utf8_general_ci;' - - mysql -u root -e "create user test2 identified by 'some password'; grant all on test2.* to test2;" - mysql -u root -e "create user test2@localhost identified by 'some password'; grant all on test2.* to test2@localhost;" - - cp .travis/database.json pymysql/tests/databases.json -fi diff --git a/CHANGELOG b/CHANGELOG.md similarity index 61% rename from CHANGELOG rename to CHANGELOG.md index b4372fed6..825dc47c1 100644 --- a/CHANGELOG +++ b/CHANGELOG.md @@ -1,5 +1,144 @@ # Changes +## Backward incompatible changes planned in the future. + +* Error classes in Cursor class will be removed after 2024-06 +* `Connection.set_charset(charset)` will be removed after 2024-06 +* `db` and `passwd` will emit DeprecationWarning in v1.2. See #933. +* `Connection.ping(reconnect)` change the default to not reconnect. + +## v1.1.1 + +Release date: 2024-05-21 + +> [!WARNING] +> This release fixes a vulnerability (CVE-2024-36039). +> All users are recommended to update to this version. +> +> If you can not update soon, check the input value from +> untrusted source has an expected type. Only dict input +> from untrusted source can be an attack vector. + +* Prohibit dict parameter for `Cursor.execute()`. It didn't produce valid SQL + and might cause SQL injection. (CVE-2024-36039) + +## v1.1.0 + +Release date: 2023-06-26 + +* Fixed SSCursor raising OperationalError for query timeouts on wrong statement (#1032) +* Exposed `Cursor.warning_count` to check for warnings without additional query (#1056) +* Make Cursor iterator (#995) +* Support '_' in key name in my.cnf (#1114) +* `Cursor.fetchall()` returns empty list instead of tuple (#1115). Note that `Cursor.fetchmany()` still return empty tuple after reading all rows for compatibility with Django. +* Deprecate Error classes in Cursor class (#1117) +* Add `Connection.set_character_set(charset, collation=None)`. This method is compatible with mysqlclient. (#1119) +* Deprecate `Connection.set_charset(charset)` (#1119) +* New connection always send "SET NAMES charset [COLLATE collation]" query. (#1119) + Since collation table is vary on MySQL server versions, collation in handshake is fragile. +* Support `charset="utf8mb3"` option (#1127) + + +## v1.0.3 + +Release date: 2023-03-28 + +* Dropped support of end of life MySQL version 5.6 +* Dropped support of end of life MariaDB versions below 10.3 +* Dropped support of end of life Python version 3.6 +* Removed `_last_executed` because of duplication with `_executed` by @rajat315315 in https://github.com/PyMySQL/PyMySQL/pull/948 +* Fix generating authentication response with long strings by @netch80 in https://github.com/PyMySQL/PyMySQL/pull/988 +* update pymysql.constants.CR by @Nothing4You in https://github.com/PyMySQL/PyMySQL/pull/1029 +* Document that the ssl connection parameter can be an SSLContext by @cakemanny in https://github.com/PyMySQL/PyMySQL/pull/1045 +* Raise ProgrammingError on -np.inf in addition to np.inf by @cdcadman in https://github.com/PyMySQL/PyMySQL/pull/1067 +* Use Python 3.11 release instead of -dev in tests by @Nothing4You in https://github.com/PyMySQL/PyMySQL/pull/1076 + + +## v1.0.2 + +Release date: 2021-01-09 + +* Fix `user`, `password`, `host`, `database` are still positional arguments. + All arguments of `connect()` are now keyword-only. (#941) + + +## v1.0.1 + +Release date: 2021-01-08 + +* Stop emitting DeprecationWarning for use of ``db`` and ``passwd``. + Note that they are still deprecated. (#939) +* Add ``python_requires=">=3.6"`` to setup.py. (#936) + + +## v1.0.0 + +Release date: 2021-01-07 + +Backward incompatible changes: + +* Python 2.7 and 3.5 are not supported. +* ``connect()`` uses keyword-only arguments. User must use keyword argument. +* ``connect()`` kwargs ``db`` and ``passwd`` are now deprecated; Use ``database`` and ``password`` instead. +* old_password authentication method (used by MySQL older than 4.1) is not supported. +* MySQL 5.5 and MariaDB 5.5 are not officially supported, although it may still works. +* Removed ``escape_dict``, ``escape_sequence``, and ``escape_string`` from ``pymysql`` + module. They are still in ``pymysql.converters``. + +Other changes: + +* Connection supports context manager API. ``__exit__`` closes the connection. (#886) +* Add MySQL Connector/Python compatible TLS options (#903) +* Major code cleanup; PyMySQL now uses black and flake8. + + +## v0.10.1 + +Release date: 2020-09-10 + +* Fix missing import of ProgrammingError. (#878) +* Fix auth switch request handling. (#890) + + +## v0.10.0 + +Release date: 2020-07-18 + +This version is the last version supporting Python 2.7. + +* MariaDB ed25519 auth is supported. +* Python 3.4 support is dropped. +* Context manager interface is removed from `Connection`. It will be added + with different meaning. +* MySQL warnings are not shown by default because many user report issue to + PyMySQL issue tracker when they see warning. You need to call "SHOW WARNINGS" + explicitly when you want to see warnings. +* Formatting of float object is changed from "3.14" to "3.14e0". +* Use cp1252 codec for latin1 charset. +* Fix decimal literal. +* TRUNCATED_WRONG_VALUE_FOR_FIELD, and ILLEGAL_VALUE_FOR_TYPE are now + DataError instead of InternalError. + + +## 0.9.3 + +Release date: 2018-12-18 + +* cryptography dependency is optional now. +* Fix old_password (used before MySQL 4.1) support. +* Deprecate old_password. +* Stop sending ``sys.argv[0]`` for connection attribute "program_name". +* Close connection when unknown error is happened. +* Deprecate context manager API of Connection object. + +## 0.9.2 + +Release date: 2018-07-04 + +* Disabled unintentinally enabled debug log +* Removed unintentionally installed tests + + ## 0.9.1 Release date: 2018-07-03 @@ -32,7 +171,7 @@ Release date: 2018-05-07 * Many test suite improvements, especially adding MySQL 8.0 and using Docker. Thanks to Daniel Black. -* Droppped support for old Python and MySQL whih is not tested long time. +* Dropped support for old Python and MySQL which is not tested long time. ## 0.8 @@ -110,7 +249,7 @@ Release date: 2016-08-30 Release date: 2016-07-29 * Fix SELECT JSON type cause UnicodeError -* Avoid float convertion while parsing microseconds +* Avoid float conversion while parsing microseconds * Warning has number * SSCursor supports warnings diff --git a/MANIFEST.in b/MANIFEST.in index 0a5207928..e2e577a9d 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1 @@ -include README.rst LICENSE CHANGELOG +include README.md LICENSE CHANGELOG.md diff --git a/Pipfile b/Pipfile deleted file mode 100644 index 0e142ba35..000000000 --- a/Pipfile +++ /dev/null @@ -1,12 +0,0 @@ -[[source]] -url = "https://pypi.python.org/simple" -verify_ssl = true -name = "pypi" - -[packages] -cryptography = "*" - -[dev-packages] -pytest = "*" -unittest2 = "*" -twine = "*" diff --git a/README.md b/README.md new file mode 100644 index 000000000..32f5df2f4 --- /dev/null +++ b/README.md @@ -0,0 +1,105 @@ +[![Documentation Status](https://readthedocs.org/projects/pymysql/badge/?version=latest)](https://pymysql.readthedocs.io/) +[![codecov](https://codecov.io/gh/PyMySQL/PyMySQL/branch/main/graph/badge.svg?token=ppEuaNXBW4)](https://codecov.io/gh/PyMySQL/PyMySQL) + +# PyMySQL + +This package contains a pure-Python MySQL client library, based on [PEP +249](https://www.python.org/dev/peps/pep-0249/). + +## Requirements + +- Python -- one of the following: + - [CPython](https://www.python.org/) : 3.7 and newer + - [PyPy](https://pypy.org/) : Latest 3.x version +- MySQL Server -- one of the following: + - [MySQL](https://www.mysql.com/) \>= 5.7 + - [MariaDB](https://mariadb.org/) \>= 10.4 + +## Installation + +Package is uploaded on [PyPI](https://pypi.org/project/PyMySQL). + +You can install it with pip: + + $ python3 -m pip install PyMySQL + +To use "sha256_password" or "caching_sha2_password" for authenticate, +you need to install additional dependency: + + $ python3 -m pip install PyMySQL[rsa] + +To use MariaDB's "ed25519" authentication method, you need to install +additional dependency: + + $ python3 -m pip install PyMySQL[ed25519] + +## Documentation + +Documentation is available online: + +For support, please refer to the +[StackOverflow](https://stackoverflow.com/questions/tagged/pymysql). + +## Example + +The following examples make use of a simple table + +``` sql +CREATE TABLE `users` ( + `id` int(11) NOT NULL AUTO_INCREMENT, + `email` varchar(255) COLLATE utf8_bin NOT NULL, + `password` varchar(255) COLLATE utf8_bin NOT NULL, + PRIMARY KEY (`id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin +AUTO_INCREMENT=1 ; +``` + +``` python +import pymysql.cursors + +# Connect to the database +connection = pymysql.connect(host='localhost', + user='user', + password='passwd', + database='db', + cursorclass=pymysql.cursors.DictCursor) + +with connection: + with connection.cursor() as cursor: + # Create a new record + sql = "INSERT INTO `users` (`email`, `password`) VALUES (%s, %s)" + cursor.execute(sql, ('webmaster@python.org', 'very-secret')) + + # connection is not autocommit by default. So you must commit to save + # your changes. + connection.commit() + + with connection.cursor() as cursor: + # Read a single record + sql = "SELECT `id`, `password` FROM `users` WHERE `email`=%s" + cursor.execute(sql, ('webmaster@python.org',)) + result = cursor.fetchone() + print(result) +``` + +This example will print: + +``` python +{'password': 'very-secret', 'id': 1} +``` + +## Resources + +- DB-API 2.0: +- MySQL Reference Manuals: +- MySQL client/server protocol: + +- "Connector" channel in MySQL Community Slack: + +- PyMySQL mailing list: + + +## License + +PyMySQL is released under the MIT License. See LICENSE for more +information. diff --git a/README.rst b/README.rst deleted file mode 100644 index 1c7fba54c..000000000 --- a/README.rst +++ /dev/null @@ -1,145 +0,0 @@ -.. image:: https://readthedocs.org/projects/pymysql/badge/?version=latest - :target: https://pymysql.readthedocs.io/ - :alt: Documentation Status - -.. image:: https://badge.fury.io/py/PyMySQL.svg - :target: https://badge.fury.io/py/PyMySQL - -.. image:: https://travis-ci.org/PyMySQL/PyMySQL.svg?branch=master - :target: https://travis-ci.org/PyMySQL/PyMySQL - -.. image:: https://coveralls.io/repos/PyMySQL/PyMySQL/badge.svg?branch=master&service=github - :target: https://coveralls.io/github/PyMySQL/PyMySQL?branch=master - -.. image:: https://img.shields.io/badge/license-MIT-blue.svg - :target: https://github.com/PyMySQL/PyMySQL/blob/master/LICENSE - - -PyMySQL -======= - -.. contents:: Table of Contents - :local: - -This package contains a pure-Python MySQL client library, based on `PEP 249`_. - -Most public APIs are compatible with mysqlclient and MySQLdb. - -NOTE: PyMySQL doesn't support low level APIs `_mysql` provides like `data_seek`, -`store_result`, and `use_result`. You should use high level APIs defined in `PEP 249`_. -But some APIs like `autocommit` and `ping` are supported because `PEP 249`_ doesn't cover -their usecase. - -.. _`PEP 249`: https://www.python.org/dev/peps/pep-0249/ - - -Requirements -------------- - -* Python -- one of the following: - - - CPython_ : 2.7 and >= 3.4 - - PyPy_ : Latest version - -* MySQL Server -- one of the following: - - - MySQL_ >= 5.5 - - MariaDB_ >= 5.5 - -.. _CPython: https://www.python.org/ -.. _PyPy: https://pypy.org/ -.. _MySQL: https://www.mysql.com/ -.. _MariaDB: https://mariadb.org/ - - -Installation ------------- - -Package is uploaded on `PyPI `_. - -You can install it with pip:: - - $ pip3 install PyMySQL - - -Documentation -------------- - -Documentation is available online: https://pymysql.readthedocs.io/ - -For support, please refer to the `StackOverflow -`_. - -Example -------- - -The following examples make use of a simple table - -.. code:: sql - - CREATE TABLE `users` ( - `id` int(11) NOT NULL AUTO_INCREMENT, - `email` varchar(255) COLLATE utf8_bin NOT NULL, - `password` varchar(255) COLLATE utf8_bin NOT NULL, - PRIMARY KEY (`id`) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin - AUTO_INCREMENT=1 ; - - -.. code:: python - - import pymysql.cursors - - # Connect to the database - connection = pymysql.connect(host='localhost', - user='user', - password='passwd', - db='db', - charset='utf8mb4', - cursorclass=pymysql.cursors.DictCursor) - - try: - with connection.cursor() as cursor: - # Create a new record - sql = "INSERT INTO `users` (`email`, `password`) VALUES (%s, %s)" - cursor.execute(sql, ('webmaster@python.org', 'very-secret')) - - # connection is not autocommit by default. So you must commit to save - # your changes. - connection.commit() - - with connection.cursor() as cursor: - # Read a single record - sql = "SELECT `id`, `password` FROM `users` WHERE `email`=%s" - cursor.execute(sql, ('webmaster@python.org',)) - result = cursor.fetchone() - print(result) - finally: - connection.close() - -This example will print: - -.. code:: python - - {'password': 'very-secret', 'id': 1} - - -Resources ---------- - -* DB-API 2.0: http://www.python.org/dev/peps/pep-0249 - -* MySQL Reference Manuals: http://dev.mysql.com/doc/ - -* MySQL client/server protocol: - http://dev.mysql.com/doc/internals/en/client-server-protocol.html - -* "Connector" channel in MySQL Community Slack: - http://lefred.be/mysql-community-on-slack/ - -* PyMySQL mailing list: https://groups.google.com/forum/#!forum/pymysql-users - -License -------- - -PyMySQL is released under the MIT License. See LICENSE for more information. diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..da9c516dd --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,5 @@ +## Security contact information + +To report a security vulnerability, please use the +[Tidelift security contact](https://tidelift.com/security). +Tidelift will coordinate the fix and disclosure. diff --git a/ci/database.json b/ci/database.json new file mode 100644 index 000000000..aad0bfb29 --- /dev/null +++ b/ci/database.json @@ -0,0 +1,4 @@ +[ + {"host": "localhost", "unix_socket": "/var/run/mysqld/mysqld.sock", "user": "root", "password": "", "database": "test1", "use_unicode": true, "local_infile": true}, + {"host": "127.0.0.1", "port": 3306, "user": "test2", "password": "some password", "database": "test2" } +] diff --git a/ci/docker-entrypoint-initdb.d/README b/ci/docker-entrypoint-initdb.d/README new file mode 100644 index 000000000..6a54b93da --- /dev/null +++ b/ci/docker-entrypoint-initdb.d/README @@ -0,0 +1,12 @@ +To test with a MariaDB or MySQL container image: + +docker run -d -p 3306:3306 -e MYSQL_ALLOW_EMPTY_PASSWORD=1 \ + --name=mysqld -v ./ci/docker-entrypoint-initdb.d:/docker-entrypoint-initdb.d:z \ + mysql:8.0.26 --local-infile=1 + +cp ci/docker.json pymysql/tests/databases.json + +pytest + + +Note: Some authentication tests that don't match the image version will fail. diff --git a/ci/docker-entrypoint-initdb.d/init.sql b/ci/docker-entrypoint-initdb.d/init.sql new file mode 100644 index 000000000..b741d41c5 --- /dev/null +++ b/ci/docker-entrypoint-initdb.d/init.sql @@ -0,0 +1,7 @@ +create database test1 DEFAULT CHARACTER SET utf8mb4; +create database test2 DEFAULT CHARACTER SET utf8mb4; +create user test2 identified by 'some password'; +grant all on test2.* to test2; +create user test2@localhost identified by 'some password'; +grant all on test2.* to test2@localhost; + diff --git a/ci/docker-entrypoint-initdb.d/mariadb.sql b/ci/docker-entrypoint-initdb.d/mariadb.sql new file mode 100644 index 000000000..912d365a9 --- /dev/null +++ b/ci/docker-entrypoint-initdb.d/mariadb.sql @@ -0,0 +1,2 @@ +/*M!100122 INSTALL SONAME "auth_ed25519" */; +/*M!100122 CREATE FUNCTION ed25519_password RETURNS STRING SONAME "auth_ed25519.so" */; diff --git a/ci/docker-entrypoint-initdb.d/mysql.sql b/ci/docker-entrypoint-initdb.d/mysql.sql new file mode 100644 index 000000000..a4ba0927d --- /dev/null +++ b/ci/docker-entrypoint-initdb.d/mysql.sql @@ -0,0 +1,8 @@ +/*!80001 CREATE USER + user_sha256 IDENTIFIED WITH "sha256_password" BY "pass_sha256_01234567890123456789", + nopass_sha256 IDENTIFIED WITH "sha256_password", + user_caching_sha2 IDENTIFIED WITH "caching_sha2_password" BY "pass_caching_sha2_01234567890123456789", + nopass_caching_sha2 IDENTIFIED WITH "caching_sha2_password" + PASSWORD EXPIRE NEVER */; + +/*!80001 GRANT RELOAD ON *.* TO user_caching_sha2 */; diff --git a/ci/docker.json b/ci/docker.json new file mode 100644 index 000000000..63d19a687 --- /dev/null +++ b/ci/docker.json @@ -0,0 +1,5 @@ +[ + {"host": "127.0.0.1", "port": 3306, "user": "root", "password": "", "database": "test1", "use_unicode": true, "local_infile": true}, + {"host": "127.0.0.1", "port": 3306, "user": "test2", "password": "some password", "database": "test2" }, + {"host": "localhost", "port": 3306, "user": "test2", "password": "some password", "database": "test2", "unix_socket": "/run/mysqld/mysqld.sock"} +] diff --git a/ci/test_mysql.py b/ci/test_mysql.py new file mode 100644 index 000000000..b97978a27 --- /dev/null +++ b/ci/test_mysql.py @@ -0,0 +1,47 @@ +# This is an example test settings file for use with the Django test suite. +# +# The 'sqlite3' backend requires only the ENGINE setting (an in- +# memory database will be used). All other backends will require a +# NAME and potentially authentication information. See the +# following section in the docs for more information: +# +# https://docs.djangoproject.com/en/dev/internals/contributing/writing-code/unit-tests/ +# +# The different databases that Django supports behave differently in certain +# situations, so it is recommended to run the test suite against as many +# database backends as possible. You may want to create a separate settings +# file for each of the backends you test against. + +import pymysql + +pymysql.install_as_MySQLdb() + +DATABASES = { + "default": { + "ENGINE": "django.db.backends.mysql", + "NAME": "django_default", + "HOST": "127.0.0.1", + "USER": "scott", + "PASSWORD": "tiger", + "TEST": {"CHARSET": "utf8mb3", "COLLATION": "utf8mb3_general_ci"}, + }, + "other": { + "ENGINE": "django.db.backends.mysql", + "NAME": "django_other", + "HOST": "127.0.0.1", + "USER": "scott", + "PASSWORD": "tiger", + "TEST": {"CHARSET": "utf8mb3", "COLLATION": "utf8mb3_general_ci"}, + }, +} + +SECRET_KEY = "django_tests_secret_key" + +# Use a fast hasher to speed up tests. +PASSWORD_HASHERS = [ + "django.contrib.auth.hashers.MD5PasswordHasher", +] + +DEFAULT_AUTO_FIELD = "django.db.models.AutoField" + +USE_TZ = False diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 000000000..919adf200 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,7 @@ +# https://docs.codecov.com/docs/common-recipe-list +coverage: + status: + project: + default: + target: auto + threshold: 3% diff --git a/docs/Makefile b/docs/Makefile index d37255520..c1240d2ba 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -74,30 +74,6 @@ json: @echo @echo "Build finished; now you can process the JSON files." -htmlhelp: - $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp - @echo - @echo "Build finished; now you can run HTML Help Workshop with the" \ - ".hhp project file in $(BUILDDIR)/htmlhelp." - -qthelp: - $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp - @echo - @echo "Build finished; now you can run "qcollectiongenerator" with the" \ - ".qhcp project file in $(BUILDDIR)/qthelp, like this:" - @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/PyMySQL.qhcp" - @echo "To view the help file:" - @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/PyMySQL.qhc" - -devhelp: - $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp - @echo - @echo "Build finished." - @echo "To view the help file:" - @echo "# mkdir -p $$HOME/.local/share/devhelp/PyMySQL" - @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/PyMySQL" - @echo "# devhelp" - epub: $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub @echo diff --git a/docs/make.bat b/docs/make.bat deleted file mode 100644 index dcd4287c6..000000000 --- a/docs/make.bat +++ /dev/null @@ -1,242 +0,0 @@ -@ECHO OFF - -REM Command file for Sphinx documentation - -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -) -set BUILDDIR=build -set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% source -set I18NSPHINXOPTS=%SPHINXOPTS% source -if NOT "%PAPER%" == "" ( - set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% - set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% -) - -if "%1" == "" goto help - -if "%1" == "help" ( - :help - echo.Please use `make ^` where ^ is one of - echo. html to make standalone HTML files - echo. dirhtml to make HTML files named index.html in directories - echo. singlehtml to make a single large HTML file - echo. pickle to make pickle files - echo. json to make JSON files - echo. htmlhelp to make HTML files and a HTML help project - echo. qthelp to make HTML files and a qthelp project - echo. devhelp to make HTML files and a Devhelp project - echo. epub to make an epub - echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter - echo. text to make text files - echo. man to make manual pages - echo. texinfo to make Texinfo files - echo. gettext to make PO message catalogs - echo. changes to make an overview over all changed/added/deprecated items - echo. xml to make Docutils-native XML files - echo. pseudoxml to make pseudoxml-XML files for display purposes - echo. linkcheck to check all external links for integrity - echo. doctest to run all doctests embedded in the documentation if enabled - goto end -) - -if "%1" == "clean" ( - for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i - del /q /s %BUILDDIR%\* - goto end -) - - -%SPHINXBUILD% 2> nul -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. - echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ - exit /b 1 -) - -if "%1" == "html" ( - %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/html. - goto end -) - -if "%1" == "dirhtml" ( - %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. - goto end -) - -if "%1" == "singlehtml" ( - %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. - goto end -) - -if "%1" == "pickle" ( - %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can process the pickle files. - goto end -) - -if "%1" == "json" ( - %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can process the JSON files. - goto end -) - -if "%1" == "htmlhelp" ( - %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can run HTML Help Workshop with the ^ -.hhp project file in %BUILDDIR%/htmlhelp. - goto end -) - -if "%1" == "qthelp" ( - %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can run "qcollectiongenerator" with the ^ -.qhcp project file in %BUILDDIR%/qthelp, like this: - echo.^> qcollectiongenerator %BUILDDIR%\qthelp\PyMySQL.qhcp - echo.To view the help file: - echo.^> assistant -collectionFile %BUILDDIR%\qthelp\PyMySQL.ghc - goto end -) - -if "%1" == "devhelp" ( - %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. - goto end -) - -if "%1" == "epub" ( - %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The epub file is in %BUILDDIR%/epub. - goto end -) - -if "%1" == "latex" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "latexpdf" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - cd %BUILDDIR%/latex - make all-pdf - cd %BUILDDIR%/.. - echo. - echo.Build finished; the PDF files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "latexpdfja" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - cd %BUILDDIR%/latex - make all-pdf-ja - cd %BUILDDIR%/.. - echo. - echo.Build finished; the PDF files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "text" ( - %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The text files are in %BUILDDIR%/text. - goto end -) - -if "%1" == "man" ( - %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The manual pages are in %BUILDDIR%/man. - goto end -) - -if "%1" == "texinfo" ( - %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. - goto end -) - -if "%1" == "gettext" ( - %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The message catalogs are in %BUILDDIR%/locale. - goto end -) - -if "%1" == "changes" ( - %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes - if errorlevel 1 exit /b 1 - echo. - echo.The overview file is in %BUILDDIR%/changes. - goto end -) - -if "%1" == "linkcheck" ( - %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck - if errorlevel 1 exit /b 1 - echo. - echo.Link check complete; look for any errors in the above output ^ -or in %BUILDDIR%/linkcheck/output.txt. - goto end -) - -if "%1" == "doctest" ( - %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest - if errorlevel 1 exit /b 1 - echo. - echo.Testing of doctests in the sources finished, look at the ^ -results in %BUILDDIR%/doctest/output.txt. - goto end -) - -if "%1" == "xml" ( - %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The XML files are in %BUILDDIR%/xml. - goto end -) - -if "%1" == "pseudoxml" ( - %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. - goto end -) - -:end diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 000000000..014066235 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,2 @@ +sphinx~=7.2 +sphinx-rtd-theme~=2.0.0 diff --git a/docs/source/conf.py b/docs/source/conf.py index bbadcbed1..158d0d12f 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- -# # PyMySQL documentation build configuration file, created by # sphinx-quickstart on Tue May 17 12:01:11 2016. # @@ -18,55 +16,54 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.insert(0, os.path.abspath('../../')) +sys.path.insert(0, os.path.abspath("../../")) # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', + "sphinx.ext.autodoc", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'PyMySQL' -copyright = u'2016, Yutaka Matsubara and GitHub contributors' +project = "PyMySQL" +copyright = "2023, Inada Naoki and GitHub contributors" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = '0.7' +version = "0.7" # The full version, including alpha/beta/rc tags. -release = '0.7.2' +release = "0.7.2" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. -#language = None +# language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. @@ -74,154 +71,157 @@ # The reST default role (used for this markup: `text`) to use for all # documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built documents. -#keep_warnings = False +# keep_warnings = False # -- Options for HTML output ---------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'default' +html_theme = "sphinx_rtd_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -#html_theme_options = {} +# html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] +# html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +# html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -#html_favicon = None +# html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied # directly to the root of the documentation. -#html_extra_path = [] +# html_extra_path = [] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True +# html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = 'PyMySQLdoc' +htmlhelp_basename = "PyMySQLdoc" # -- Options for LaTeX output --------------------------------------------- latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', - -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -#'preamble': '', + # The paper size ('letterpaper' or 'a4paper'). + #'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + #'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + #'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - ('index', 'PyMySQL.tex', u'PyMySQL Documentation', - u'Yutaka Matsubara and GitHub contributors', 'manual'), + ( + "index", + "PyMySQL.tex", + "PyMySQL Documentation", + "Yutaka Matsubara and GitHub contributors", + "manual", + ), ] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output --------------------------------------- @@ -229,12 +229,17 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - ('index', 'pymysql', u'PyMySQL Documentation', - [u'Yutaka Matsubara and GitHub contributors'], 1) + ( + "index", + "pymysql", + "PyMySQL Documentation", + ["Yutaka Matsubara and GitHub contributors"], + 1, + ) ] # If true, show URL addresses after external links. -#man_show_urls = False +# man_show_urls = False # -- Options for Texinfo output ------------------------------------------- @@ -243,23 +248,29 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', 'PyMySQL', u'PyMySQL Documentation', - u'Yutaka Matsubara and GitHub contributors', 'PyMySQL', 'One line description of project.', - 'Miscellaneous'), + ( + "index", + "PyMySQL", + "PyMySQL Documentation", + "Yutaka Matsubara and GitHub contributors", + "PyMySQL", + "One line description of project.", + "Miscellaneous", + ), ] # Documents to append as an appendix to all manuals. -#texinfo_appendices = [] +# texinfo_appendices = [] # If false, no module index is generated. -#texinfo_domain_indices = True +# texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' +# texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. -#texinfo_no_detailmenu = False +# texinfo_no_detailmenu = False # Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = {'http://docs.python.org/': None} +intersphinx_mapping = {"http://docs.python.org/": None} diff --git a/docs/source/index.rst b/docs/source/index.rst index 97633f1aa..e64b64238 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,5 +1,5 @@ -Welcome to PyMySQL's documentation! -=================================== +PyMySQL documentation +===================== .. toctree:: :maxdepth: 2 diff --git a/docs/source/user/development.rst b/docs/source/user/development.rst index 39c40e1a7..1f8a2637f 100644 --- a/docs/source/user/development.rst +++ b/docs/source/user/development.rst @@ -22,17 +22,13 @@ If you would like to run the test suite, create a database for testing like this mysql -e 'create database test_pymysql DEFAULT CHARACTER SET utf8 DEFAULT COLLATE utf8_general_ci;' mysql -e 'create database test_pymysql2 DEFAULT CHARACTER SET utf8 DEFAULT COLLATE utf8_general_ci;' -Then, copy the file ``.travis/database.json`` to ``pymysql/tests/databases.json`` +Then, copy the file ``ci/database.json`` to ``pymysql/tests/databases.json`` and edit the new file to match your MySQL configuration:: - $ cp .travis/database.json pymysql/tests/databases.json + $ cp ci/database.json pymysql/tests/databases.json $ $EDITOR pymysql/tests/databases.json To run all the tests, execute the script ``runtests.py``:: - $ python runtests.py - -A ``tox.ini`` file is also provided for conveniently running tests on multiple -Python versions:: - - $ tox + $ pip install pytest + $ pytest -v pymysql diff --git a/docs/source/user/examples.rst b/docs/source/user/examples.rst index 87af40c37..3946db9b9 100644 --- a/docs/source/user/examples.rst +++ b/docs/source/user/examples.rst @@ -18,7 +18,7 @@ The following examples make use of a simple table `email` varchar(255) COLLATE utf8_bin NOT NULL, `password` varchar(255) COLLATE utf8_bin NOT NULL, PRIMARY KEY (`id`) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin AUTO_INCREMENT=1 ; @@ -30,11 +30,11 @@ The following examples make use of a simple table connection = pymysql.connect(host='localhost', user='user', password='passwd', - db='db', + database='db', charset='utf8mb4', cursorclass=pymysql.cursors.DictCursor) - try: + with connection: with connection.cursor() as cursor: # Create a new record sql = "INSERT INTO `users` (`email`, `password`) VALUES (%s, %s)" @@ -50,11 +50,10 @@ The following examples make use of a simple table cursor.execute(sql, ('webmaster@python.org',)) result = cursor.fetchone() print(result) - finally: - connection.close() + This example will print: .. code:: python - {'password': 'very-secret', 'id': 1} + {'id': 1, 'password': 'very-secret'} diff --git a/docs/source/user/installation.rst b/docs/source/user/installation.rst index e3bfe84d0..9313f14d3 100644 --- a/docs/source/user/installation.rst +++ b/docs/source/user/installation.rst @@ -6,24 +6,27 @@ Installation The last stable release is available on PyPI and can be installed with ``pip``:: - $ pip install PyMySQL + $ python3 -m pip install PyMySQL + +To use "sha256_password" or "caching_sha2_password" for authenticate, +you need to install additional dependency:: + + $ python3 -m pip install PyMySQL[rsa] Requirements ------------- * Python -- one of the following: - - CPython_ >= 2.6 or >= 3.3 - - PyPy_ >= 4.0 - - IronPython_ 2.7 + - CPython_ >= 3.7 + - Latest PyPy_ 3 * MySQL Server -- one of the following: - - MySQL_ >= 4.1 (tested with only 5.5~) - - MariaDB_ >= 5.1 + - MySQL_ >= 5.7 + - MariaDB_ >= 10.3 .. _CPython: http://www.python.org/ .. _PyPy: http://pypy.org/ -.. _IronPython: http://ironpython.net/ .. _MySQL: http://www.mysql.com/ .. _MariaDB: https://mariadb.org/ diff --git a/example.py b/example.py index 68582138d..c12f103b5 100644 --- a/example.py +++ b/example.py @@ -1,16 +1,13 @@ #!/usr/bin/env python -from __future__ import print_function - import pymysql -conn = pymysql.connect(host='localhost', port=3306, user='root', passwd='', db='mysql') +conn = pymysql.connect(host="localhost", port=3306, user="root", passwd="", db="mysql") cur = conn.cursor() cur.execute("SELECT Host,User FROM user") print(cur.description) - print() for row in cur: diff --git a/pymysql/__init__.py b/pymysql/__init__.py index b79b4b83e..bbf9023ef 100644 --- a/pymysql/__init__.py +++ b/pymysql/__init__.py @@ -21,32 +21,65 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + import sys -from ._compat import PY2 from .constants import FIELD_TYPE -from .converters import escape_dict, escape_sequence, escape_string from .err import ( - Warning, Error, InterfaceError, DataError, - DatabaseError, OperationalError, IntegrityError, InternalError, - NotSupportedError, ProgrammingError, MySQLError) + Warning, + Error, + InterfaceError, + DataError, + DatabaseError, + OperationalError, + IntegrityError, + InternalError, + NotSupportedError, + ProgrammingError, + MySQLError, +) from .times import ( - Date, Time, Timestamp, - DateFromTicks, TimeFromTicks, TimestampFromTicks) + Date, + Time, + Timestamp, + DateFromTicks, + TimeFromTicks, + TimestampFromTicks, +) + +# PyMySQL version. +# Used by setuptools and connection_attrs +VERSION = (1, 1, 1, "final", 1) +VERSION_STRING = "1.1.1" + +### for mysqlclient compatibility +### Django checks mysqlclient version. +version_info = (1, 4, 6, "final", 1) +__version__ = "1.4.6" + + +def get_client_info(): # for MySQLdb compatibility + return __version__ + + +def install_as_MySQLdb(): + """ + After this function is called, any application that imports MySQLdb + will unwittingly actually use pymysql. + """ + sys.modules["MySQLdb"] = sys.modules["pymysql"] + +# end of mysqlclient compatibility code -VERSION = (0, 9, 2, None) -if VERSION[3] is not None: - VERSION_STRING = "%d.%d.%d_%s" % VERSION -else: - VERSION_STRING = "%d.%d.%d" % VERSION[:3] threadsafety = 1 apilevel = "2.0" paramstyle = "pyformat" +from . import connections # noqa: E402 -class DBAPISet(frozenset): +class DBAPISet(frozenset): def __ne__(self, other): if isinstance(other, set): return frozenset.__ne__(self, other) @@ -63,79 +96,88 @@ def __hash__(self): return frozenset.__hash__(self) -STRING = DBAPISet([FIELD_TYPE.ENUM, FIELD_TYPE.STRING, - FIELD_TYPE.VAR_STRING]) -BINARY = DBAPISet([FIELD_TYPE.BLOB, FIELD_TYPE.LONG_BLOB, - FIELD_TYPE.MEDIUM_BLOB, FIELD_TYPE.TINY_BLOB]) -NUMBER = DBAPISet([FIELD_TYPE.DECIMAL, FIELD_TYPE.DOUBLE, FIELD_TYPE.FLOAT, - FIELD_TYPE.INT24, FIELD_TYPE.LONG, FIELD_TYPE.LONGLONG, - FIELD_TYPE.TINY, FIELD_TYPE.YEAR]) -DATE = DBAPISet([FIELD_TYPE.DATE, FIELD_TYPE.NEWDATE]) -TIME = DBAPISet([FIELD_TYPE.TIME]) +STRING = DBAPISet([FIELD_TYPE.ENUM, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING]) +BINARY = DBAPISet( + [ + FIELD_TYPE.BLOB, + FIELD_TYPE.LONG_BLOB, + FIELD_TYPE.MEDIUM_BLOB, + FIELD_TYPE.TINY_BLOB, + ] +) +NUMBER = DBAPISet( + [ + FIELD_TYPE.DECIMAL, + FIELD_TYPE.DOUBLE, + FIELD_TYPE.FLOAT, + FIELD_TYPE.INT24, + FIELD_TYPE.LONG, + FIELD_TYPE.LONGLONG, + FIELD_TYPE.TINY, + FIELD_TYPE.YEAR, + ] +) +DATE = DBAPISet([FIELD_TYPE.DATE, FIELD_TYPE.NEWDATE]) +TIME = DBAPISet([FIELD_TYPE.TIME]) TIMESTAMP = DBAPISet([FIELD_TYPE.TIMESTAMP, FIELD_TYPE.DATETIME]) -DATETIME = TIMESTAMP -ROWID = DBAPISet() +DATETIME = TIMESTAMP +ROWID = DBAPISet() def Binary(x): """Return x as a binary type.""" - if PY2: - return bytearray(x) - else: - return bytes(x) - - -def Connect(*args, **kwargs): - """ - Connect to the database; see connections.Connection.__init__() for - more information. - """ - from .connections import Connection - return Connection(*args, **kwargs) - -from . import connections as _orig_conn -if _orig_conn.Connection.__init__.__doc__ is not None: - Connect.__doc__ = _orig_conn.Connection.__init__.__doc__ -del _orig_conn - - -def get_client_info(): # for MySQLdb compatibility - version = VERSION - if VERSION[3] is None: - version = VERSION[:3] - return '.'.join(map(str, version)) - -connect = Connection = Connect - -# we include a doctored version_info here for MySQLdb compatibility -version_info = (1, 3, 12, "final", 0) + return bytes(x) -NULL = "NULL" - -__version__ = get_client_info() def thread_safe(): return True # match MySQLdb.thread_safe() -def install_as_MySQLdb(): - """ - After this function is called, any application that imports MySQLdb or - _mysql will unwittingly actually use pymysql. - """ - sys.modules["MySQLdb"] = sys.modules["_mysql"] = sys.modules["pymysql"] + +Connect = connect = Connection = connections.Connection +NULL = "NULL" __all__ = [ - 'BINARY', 'Binary', 'Connect', 'Connection', 'DATE', 'Date', - 'Time', 'Timestamp', 'DateFromTicks', 'TimeFromTicks', 'TimestampFromTicks', - 'DataError', 'DatabaseError', 'Error', 'FIELD_TYPE', 'IntegrityError', - 'InterfaceError', 'InternalError', 'MySQLError', 'NULL', 'NUMBER', - 'NotSupportedError', 'DBAPISet', 'OperationalError', 'ProgrammingError', - 'ROWID', 'STRING', 'TIME', 'TIMESTAMP', 'Warning', 'apilevel', 'connect', - 'connections', 'constants', 'converters', 'cursors', - 'escape_dict', 'escape_sequence', 'escape_string', 'get_client_info', - 'paramstyle', 'threadsafety', 'version_info', - + "BINARY", + "Binary", + "Connect", + "Connection", + "DATE", + "Date", + "Time", + "Timestamp", + "DateFromTicks", + "TimeFromTicks", + "TimestampFromTicks", + "DataError", + "DatabaseError", + "Error", + "FIELD_TYPE", + "IntegrityError", + "InterfaceError", + "InternalError", + "MySQLError", + "NULL", + "NUMBER", + "NotSupportedError", + "DBAPISet", + "OperationalError", + "ProgrammingError", + "ROWID", + "STRING", + "TIME", + "TIMESTAMP", + "Warning", + "apilevel", + "connect", + "connections", + "constants", + "converters", + "cursors", + "get_client_info", + "paramstyle", + "threadsafety", + "version_info", "install_as_MySQLdb", - "NULL", "__version__", + "__version__", ] diff --git a/pymysql/_auth.py b/pymysql/_auth.py index bbb742d3a..8ce744fb5 100644 --- a/pymysql/_auth.py +++ b/pymysql/_auth.py @@ -1,22 +1,26 @@ """ Implements auth methods """ -from ._compat import text_type, PY2 -from .constants import CLIENT + from .err import OperationalError -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import serialization, hashes -from cryptography.hazmat.primitives.asymmetric import padding + +try: + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization, hashes + from cryptography.hazmat.primitives.asymmetric import padding + + _have_cryptography = True +except ImportError: + _have_cryptography = False from functools import partial import hashlib -import struct DEBUG = False SCRAMBLE_LENGTH = 20 -sha1_new = partial(hashlib.new, 'sha1') +sha1_new = partial(hashlib.new, "sha1") # mysql_native_password @@ -26,7 +30,7 @@ def scramble_native_password(password, message): """Scramble used for mysql_native_password""" if not password: - return b'' + return b"" stage1 = sha1_new(password).digest() stage2 = sha1_new(stage1).digest() @@ -39,8 +43,6 @@ def scramble_native_password(password, message): def _my_crypt(message1, message2): result = bytearray(message1) - if PY2: - message2 = bytearray(message2) for i in range(len(result)): result[i] ^= message2[i] @@ -48,60 +50,67 @@ def _my_crypt(message1, message2): return bytes(result) -# old_passwords support ported from libmysql/password.c -# https://dev.mysql.com/doc/internals/en/old-password-authentication.html +# MariaDB's client_ed25519-plugin +# https://mariadb.com/kb/en/library/connection/#client_ed25519-plugin -SCRAMBLE_LENGTH_323 = 8 +_nacl_bindings = False -class RandStruct_323(object): +def _init_nacl(): + global _nacl_bindings + try: + from nacl import bindings - def __init__(self, seed1, seed2): - self.max_value = 0x3FFFFFFF - self.seed1 = seed1 % self.max_value - self.seed2 = seed2 % self.max_value + _nacl_bindings = bindings + except ImportError: + raise RuntimeError( + "'pynacl' package is required for ed25519_password auth method" + ) - def my_rnd(self): - self.seed1 = (self.seed1 * 3 + self.seed2) % self.max_value - self.seed2 = (self.seed1 + self.seed2 + 33) % self.max_value - return float(self.seed1) / float(self.max_value) +def _scalar_clamp(s32): + ba = bytearray(s32) + ba0 = bytes(bytearray([ba[0] & 248])) + ba31 = bytes(bytearray([(ba[31] & 127) | 64])) + return ba0 + bytes(s32[1:31]) + ba31 -def scramble_old_password(password, message): - """Scramble for old_password""" - hash_pass = _hash_password_323(password) - hash_message = _hash_password_323(message[:SCRAMBLE_LENGTH_323]) - hash_pass_n = struct.unpack(">LL", hash_pass) - hash_message_n = struct.unpack(">LL", hash_message) - rand_st = RandStruct_323( - hash_pass_n[0] ^ hash_message_n[0], hash_pass_n[1] ^ hash_message_n[1] - ) - outbuf = io.BytesIO() - for _ in range(min(SCRAMBLE_LENGTH_323, len(message))): - outbuf.write(int2byte(int(rand_st.my_rnd() * 31) + 64)) - extra = int2byte(int(rand_st.my_rnd() * 31)) - out = outbuf.getvalue() - outbuf = io.BytesIO() - for c in out: - outbuf.write(int2byte(byte2int(c) ^ byte2int(extra))) - return outbuf.getvalue() +def ed25519_password(password, scramble): + """Sign a random scramble with elliptic curve Ed25519. + + Secret and public key are derived from password. + """ + # variable names based on rfc8032 section-5.1.6 + # + if not _nacl_bindings: + _init_nacl() + + # h = SHA512(password) + h = hashlib.sha512(password).digest() + + # s = prune(first_half(h)) + s = _scalar_clamp(h[:32]) + # r = SHA512(second_half(h) || M) + r = hashlib.sha512(h[32:] + scramble).digest() -def _hash_password_323(password): - nr = 1345345333 - add = 7 - nr2 = 0x12345671 + # R = encoded point [r]B + r = _nacl_bindings.crypto_core_ed25519_scalar_reduce(r) + R = _nacl_bindings.crypto_scalarmult_ed25519_base_noclamp(r) - # x in py3 is numbers, p27 is chars - for c in [byte2int(x) for x in password if x not in (' ', '\t', 32, 9)]: - nr ^= (((nr & 63) + add) * c) + (nr << 8) & 0xFFFFFFFF - nr2 = (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF - add = (add + c) & 0xFFFFFFFF + # A = encoded point [s]B + A = _nacl_bindings.crypto_scalarmult_ed25519_base_noclamp(s) - r1 = nr & ((1 << 31) - 1) # kill sign bits - r2 = nr2 & ((1 << 31) - 1) - return struct.pack(">LL", r1, r2) + # k = SHA512(R || A || M) + k = hashlib.sha512(R + A + scramble).digest() + + # S = (k * s + r) mod L + k = _nacl_bindings.crypto_core_ed25519_scalar_reduce(k) + ks = _nacl_bindings.crypto_core_ed25519_scalar_mul(k, s) + S = _nacl_bindings.crypto_core_ed25519_scalar_add(ks, r) + + # signature = R || S + return R + S # sha256_password @@ -115,8 +124,11 @@ def _roundtrip(conn, send_data): def _xor_password(password, salt): + # Trailing NUL character will be added in Auth Switch Request. + # See https://github.com/mysql/mysql-server/blob/7d10c82196c8e45554f27c00681474a9fb86d137/sql/auth/sha2_password.cc#L939-L945 + salt = salt[:SCRAMBLE_LENGTH] password_bytes = bytearray(password) - salt = bytearray(salt) # for PY2 compat. + # salt = bytearray(salt) # for PY2 compat. salt_len = len(salt) for i in range(len(password_bytes)): password_bytes[i] ^= salt[i % salt_len] @@ -128,7 +140,12 @@ def sha2_rsa_encrypt(password, salt, public_key): Used for sha256_password and caching_sha2_password. """ - message = _xor_password(password + b'\0', salt) + if not _have_cryptography: + raise RuntimeError( + "'cryptography' package is required for sha256_password or" + + " caching_sha2_password auth methods" + ) + message = _xor_password(password + b"\0", salt) rsa_key = serialization.load_pem_public_key(public_key, default_backend()) return rsa_key.encrypt( message, @@ -144,7 +161,7 @@ def sha256_password_auth(conn, pkt): if conn._secure: if DEBUG: print("sha256: Sending plain password") - data = conn.password + b'\0' + data = conn.password + b"\0" return _roundtrip(conn, data) if pkt.is_auth_switch_request(): @@ -153,12 +170,12 @@ def sha256_password_auth(conn, pkt): # Request server public key if DEBUG: print("sha256: Requesting server public key") - pkt = _roundtrip(conn, b'\1') + pkt = _roundtrip(conn, b"\1") if pkt.is_extra_auth_data(): conn.server_public_key = pkt._data[1:] if DEBUG: - print("Received public key:\n", conn.server_public_key.decode('ascii')) + print("Received public key:\n", conn.server_public_key.decode("ascii")) if conn.password: if not conn.server_public_key: @@ -166,7 +183,7 @@ def sha256_password_auth(conn, pkt): data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key) else: - data = b'' + data = b"" return _roundtrip(conn, data) @@ -178,15 +195,13 @@ def scramble_caching_sha2(password, nonce): XOR(SHA256(password), SHA256(SHA256(SHA256(password)), nonce)) """ if not password: - return b'' + return b"" p1 = hashlib.sha256(password).digest() p2 = hashlib.sha256(p1).digest() p3 = hashlib.sha256(p2 + nonce).digest() res = bytearray(p1) - if PY2: - p3 = bytearray(p3) for i in range(len(p3)): res[i] ^= p3[i] @@ -196,7 +211,7 @@ def scramble_caching_sha2(password, nonce): def caching_sha2_password_auth(conn, pkt): # No password fast path if not conn.password: - return _roundtrip(conn, b'') + return _roundtrip(conn, b"") if pkt.is_auth_switch_request(): # Try from fast auth @@ -228,7 +243,7 @@ def caching_sha2_password_auth(conn, pkt): return pkt if n != 4: - raise OperationalError("caching sha2: Unknwon result for fast auth: %s" % n) + raise OperationalError("caching sha2: Unknown result for fast auth: %s" % n) if DEBUG: print("caching sha2: Trying full auth...") @@ -236,10 +251,10 @@ def caching_sha2_password_auth(conn, pkt): if conn._secure: if DEBUG: print("caching sha2: Sending plain password via secure connection") - return _roundtrip(conn, conn.password + b'\0') + return _roundtrip(conn, conn.password + b"\0") if not conn.server_public_key: - pkt = _roundtrip(conn, b'\x02') # Request public key + pkt = _roundtrip(conn, b"\x02") # Request public key if not pkt.is_extra_auth_data(): raise OperationalError( "caching sha2: Unknown packet for public key: %s" % pkt._data[:1] @@ -247,7 +262,7 @@ def caching_sha2_password_auth(conn, pkt): conn.server_public_key = pkt._data[1:] if DEBUG: - print(conn.server_public_key.decode('ascii')) + print(conn.server_public_key.decode("ascii")) data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key) pkt = _roundtrip(conn, data) diff --git a/pymysql/_compat.py b/pymysql/_compat.py deleted file mode 100644 index 252789ec4..000000000 --- a/pymysql/_compat.py +++ /dev/null @@ -1,21 +0,0 @@ -import sys - -PY2 = sys.version_info[0] == 2 -PYPY = hasattr(sys, 'pypy_translation_info') -JYTHON = sys.platform.startswith('java') -IRONPYTHON = sys.platform == 'cli' -CPYTHON = not PYPY and not JYTHON and not IRONPYTHON - -if PY2: - import __builtin__ - range_type = xrange - text_type = unicode - long_type = long - str_type = basestring - unichr = __builtin__.unichr -else: - range_type = range - text_type = str - long_type = int - str_type = str - unichr = chr diff --git a/pymysql/_socketio.py b/pymysql/_socketio.py deleted file mode 100644 index 6a11d42e4..000000000 --- a/pymysql/_socketio.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -SocketIO imported from socket module in Python 3. - -Copyright (c) 2001-2013 Python Software Foundation; All Rights Reserved. -""" - -from socket import * -import io -import errno - -__all__ = ['SocketIO'] - -EINTR = errno.EINTR -_blocking_errnos = (errno.EAGAIN, errno.EWOULDBLOCK) - -class SocketIO(io.RawIOBase): - - """Raw I/O implementation for stream sockets. - - This class supports the makefile() method on sockets. It provides - the raw I/O interface on top of a socket object. - """ - - # One might wonder why not let FileIO do the job instead. There are two - # main reasons why FileIO is not adapted: - # - it wouldn't work under Windows (where you can't used read() and - # write() on a socket handle) - # - it wouldn't work with socket timeouts (FileIO would ignore the - # timeout and consider the socket non-blocking) - - # XXX More docs - - def __init__(self, sock, mode): - if mode not in ("r", "w", "rw", "rb", "wb", "rwb"): - raise ValueError("invalid mode: %r" % mode) - io.RawIOBase.__init__(self) - self._sock = sock - if "b" not in mode: - mode += "b" - self._mode = mode - self._reading = "r" in mode - self._writing = "w" in mode - self._timeout_occurred = False - - def readinto(self, b): - """Read up to len(b) bytes into the writable buffer *b* and return - the number of bytes read. If the socket is non-blocking and no bytes - are available, None is returned. - - If *b* is non-empty, a 0 return value indicates that the connection - was shutdown at the other end. - """ - self._checkClosed() - self._checkReadable() - if self._timeout_occurred: - raise IOError("cannot read from timed out object") - while True: - try: - return self._sock.recv_into(b) - except timeout: - self._timeout_occurred = True - raise - except error as e: - n = e.args[0] - if n == EINTR: - continue - if n in _blocking_errnos: - return None - raise - - def write(self, b): - """Write the given bytes or bytearray object *b* to the socket - and return the number of bytes written. This can be less than - len(b) if not all data could be written. If the socket is - non-blocking and no bytes could be written None is returned. - """ - self._checkClosed() - self._checkWritable() - try: - return self._sock.send(b) - except error as e: - # XXX what about EINTR? - if e.args[0] in _blocking_errnos: - return None - raise - - def readable(self): - """True if the SocketIO is open for reading. - """ - if self.closed: - raise ValueError("I/O operation on closed socket.") - return self._reading - - def writable(self): - """True if the SocketIO is open for writing. - """ - if self.closed: - raise ValueError("I/O operation on closed socket.") - return self._writing - - def seekable(self): - """True if the SocketIO is open for seeking. - """ - if self.closed: - raise ValueError("I/O operation on closed socket.") - return super().seekable() - - def fileno(self): - """Return the file descriptor of the underlying socket. - """ - self._checkClosed() - return self._sock.fileno() - - @property - def name(self): - if not self.closed: - return self.fileno() - else: - return -1 - - @property - def mode(self): - return self._mode - - def close(self): - """Close the SocketIO object. This doesn't close the underlying - socket, except if all references to it have disappeared. - """ - if self.closed: - return - io.RawIOBase.close(self) - self._sock._decref_socketios() - self._sock = None - diff --git a/pymysql/charset.py b/pymysql/charset.py index 968376cfa..b1c1ca8b8 100644 --- a/pymysql/charset.py +++ b/pymysql/charset.py @@ -1,25 +1,29 @@ -MBLENGTH = { - 8:1, - 33:3, - 88:2, - 91:2 - } +# Internal use only. Do not use directly. +MBLENGTH = {8: 1, 33: 3, 88: 2, 91: 2} -class Charset(object): - def __init__(self, id, name, collation, is_default): + +class Charset: + def __init__(self, id, name, collation, is_default=False): self.id, self.name, self.collation = id, name, collation - self.is_default = is_default == 'Yes' + self.is_default = is_default def __repr__(self): - return "Charset(id=%s, name=%r, collation=%r)" % ( - self.id, self.name, self.collation) + return ( + f"Charset(id={self.id}, name={self.name!r}, collation={self.collation!r})" + ) @property def encoding(self): name = self.name - if name == 'utf8mb4': - return 'utf8' + if name in ("utf8mb4", "utf8mb3"): + return "utf8" + if name == "latin1": + return "cp1252" + if name == "koi8r": + return "koi8_r" + if name == "koi8u": + return "koi8_u" return name @property @@ -30,241 +34,183 @@ def is_binary(self): class Charsets: def __init__(self): self._by_id = {} + self._by_name = {} def add(self, c): self._by_id[c.id] = c + if c.is_default: + self._by_name[c.name] = c def by_id(self, id): return self._by_id[id] def by_name(self, name): - name = name.lower() - for c in self._by_id.values(): - if c.name == name and c.is_default: - return c + if name == "utf8": + name = "utf8mb4" + return self._by_name.get(name.lower()) + _charsets = Charsets() +charset_by_name = _charsets.by_name +charset_by_id = _charsets.by_id + """ +TODO: update this script. + Generated with: mysql -N -s -e "select id, character_set_name, collation_name, is_default from information_schema.collations order by id;" | python -c "import sys for l in sys.stdin.readlines(): - id, name, collation, is_default = l.split(chr(9)) - print '_charsets.add(Charset(%s, \'%s\', \'%s\', \'%s\'))' \ - % (id, name, collation, is_default.strip()) -" - + id, name, collation, is_default = l.split(chr(9)) + if is_default.strip() == "Yes": + print('_charsets.add(Charset(%s, \'%s\', \'%s\', True))' \ + % (id, name, collation)) + else: + print('_charsets.add(Charset(%s, \'%s\', \'%s\'))' \ + % (id, name, collation, bool(is_default.strip())) """ -_charsets.add(Charset(1, 'big5', 'big5_chinese_ci', 'Yes')) -_charsets.add(Charset(2, 'latin2', 'latin2_czech_cs', '')) -_charsets.add(Charset(3, 'dec8', 'dec8_swedish_ci', 'Yes')) -_charsets.add(Charset(4, 'cp850', 'cp850_general_ci', 'Yes')) -_charsets.add(Charset(5, 'latin1', 'latin1_german1_ci', '')) -_charsets.add(Charset(6, 'hp8', 'hp8_english_ci', 'Yes')) -_charsets.add(Charset(7, 'koi8r', 'koi8r_general_ci', 'Yes')) -_charsets.add(Charset(8, 'latin1', 'latin1_swedish_ci', 'Yes')) -_charsets.add(Charset(9, 'latin2', 'latin2_general_ci', 'Yes')) -_charsets.add(Charset(10, 'swe7', 'swe7_swedish_ci', 'Yes')) -_charsets.add(Charset(11, 'ascii', 'ascii_general_ci', 'Yes')) -_charsets.add(Charset(12, 'ujis', 'ujis_japanese_ci', 'Yes')) -_charsets.add(Charset(13, 'sjis', 'sjis_japanese_ci', 'Yes')) -_charsets.add(Charset(14, 'cp1251', 'cp1251_bulgarian_ci', '')) -_charsets.add(Charset(15, 'latin1', 'latin1_danish_ci', '')) -_charsets.add(Charset(16, 'hebrew', 'hebrew_general_ci', 'Yes')) -_charsets.add(Charset(18, 'tis620', 'tis620_thai_ci', 'Yes')) -_charsets.add(Charset(19, 'euckr', 'euckr_korean_ci', 'Yes')) -_charsets.add(Charset(20, 'latin7', 'latin7_estonian_cs', '')) -_charsets.add(Charset(21, 'latin2', 'latin2_hungarian_ci', '')) -_charsets.add(Charset(22, 'koi8u', 'koi8u_general_ci', 'Yes')) -_charsets.add(Charset(23, 'cp1251', 'cp1251_ukrainian_ci', '')) -_charsets.add(Charset(24, 'gb2312', 'gb2312_chinese_ci', 'Yes')) -_charsets.add(Charset(25, 'greek', 'greek_general_ci', 'Yes')) -_charsets.add(Charset(26, 'cp1250', 'cp1250_general_ci', 'Yes')) -_charsets.add(Charset(27, 'latin2', 'latin2_croatian_ci', '')) -_charsets.add(Charset(28, 'gbk', 'gbk_chinese_ci', 'Yes')) -_charsets.add(Charset(29, 'cp1257', 'cp1257_lithuanian_ci', '')) -_charsets.add(Charset(30, 'latin5', 'latin5_turkish_ci', 'Yes')) -_charsets.add(Charset(31, 'latin1', 'latin1_german2_ci', '')) -_charsets.add(Charset(32, 'armscii8', 'armscii8_general_ci', 'Yes')) -_charsets.add(Charset(33, 'utf8', 'utf8_general_ci', 'Yes')) -_charsets.add(Charset(34, 'cp1250', 'cp1250_czech_cs', '')) -_charsets.add(Charset(35, 'ucs2', 'ucs2_general_ci', 'Yes')) -_charsets.add(Charset(36, 'cp866', 'cp866_general_ci', 'Yes')) -_charsets.add(Charset(37, 'keybcs2', 'keybcs2_general_ci', 'Yes')) -_charsets.add(Charset(38, 'macce', 'macce_general_ci', 'Yes')) -_charsets.add(Charset(39, 'macroman', 'macroman_general_ci', 'Yes')) -_charsets.add(Charset(40, 'cp852', 'cp852_general_ci', 'Yes')) -_charsets.add(Charset(41, 'latin7', 'latin7_general_ci', 'Yes')) -_charsets.add(Charset(42, 'latin7', 'latin7_general_cs', '')) -_charsets.add(Charset(43, 'macce', 'macce_bin', '')) -_charsets.add(Charset(44, 'cp1250', 'cp1250_croatian_ci', '')) -_charsets.add(Charset(45, 'utf8mb4', 'utf8mb4_general_ci', 'Yes')) -_charsets.add(Charset(46, 'utf8mb4', 'utf8mb4_bin', '')) -_charsets.add(Charset(47, 'latin1', 'latin1_bin', '')) -_charsets.add(Charset(48, 'latin1', 'latin1_general_ci', '')) -_charsets.add(Charset(49, 'latin1', 'latin1_general_cs', '')) -_charsets.add(Charset(50, 'cp1251', 'cp1251_bin', '')) -_charsets.add(Charset(51, 'cp1251', 'cp1251_general_ci', 'Yes')) -_charsets.add(Charset(52, 'cp1251', 'cp1251_general_cs', '')) -_charsets.add(Charset(53, 'macroman', 'macroman_bin', '')) -_charsets.add(Charset(54, 'utf16', 'utf16_general_ci', 'Yes')) -_charsets.add(Charset(55, 'utf16', 'utf16_bin', '')) -_charsets.add(Charset(57, 'cp1256', 'cp1256_general_ci', 'Yes')) -_charsets.add(Charset(58, 'cp1257', 'cp1257_bin', '')) -_charsets.add(Charset(59, 'cp1257', 'cp1257_general_ci', 'Yes')) -_charsets.add(Charset(60, 'utf32', 'utf32_general_ci', 'Yes')) -_charsets.add(Charset(61, 'utf32', 'utf32_bin', '')) -_charsets.add(Charset(63, 'binary', 'binary', 'Yes')) -_charsets.add(Charset(64, 'armscii8', 'armscii8_bin', '')) -_charsets.add(Charset(65, 'ascii', 'ascii_bin', '')) -_charsets.add(Charset(66, 'cp1250', 'cp1250_bin', '')) -_charsets.add(Charset(67, 'cp1256', 'cp1256_bin', '')) -_charsets.add(Charset(68, 'cp866', 'cp866_bin', '')) -_charsets.add(Charset(69, 'dec8', 'dec8_bin', '')) -_charsets.add(Charset(70, 'greek', 'greek_bin', '')) -_charsets.add(Charset(71, 'hebrew', 'hebrew_bin', '')) -_charsets.add(Charset(72, 'hp8', 'hp8_bin', '')) -_charsets.add(Charset(73, 'keybcs2', 'keybcs2_bin', '')) -_charsets.add(Charset(74, 'koi8r', 'koi8r_bin', '')) -_charsets.add(Charset(75, 'koi8u', 'koi8u_bin', '')) -_charsets.add(Charset(77, 'latin2', 'latin2_bin', '')) -_charsets.add(Charset(78, 'latin5', 'latin5_bin', '')) -_charsets.add(Charset(79, 'latin7', 'latin7_bin', '')) -_charsets.add(Charset(80, 'cp850', 'cp850_bin', '')) -_charsets.add(Charset(81, 'cp852', 'cp852_bin', '')) -_charsets.add(Charset(82, 'swe7', 'swe7_bin', '')) -_charsets.add(Charset(83, 'utf8', 'utf8_bin', '')) -_charsets.add(Charset(84, 'big5', 'big5_bin', '')) -_charsets.add(Charset(85, 'euckr', 'euckr_bin', '')) -_charsets.add(Charset(86, 'gb2312', 'gb2312_bin', '')) -_charsets.add(Charset(87, 'gbk', 'gbk_bin', '')) -_charsets.add(Charset(88, 'sjis', 'sjis_bin', '')) -_charsets.add(Charset(89, 'tis620', 'tis620_bin', '')) -_charsets.add(Charset(90, 'ucs2', 'ucs2_bin', '')) -_charsets.add(Charset(91, 'ujis', 'ujis_bin', '')) -_charsets.add(Charset(92, 'geostd8', 'geostd8_general_ci', 'Yes')) -_charsets.add(Charset(93, 'geostd8', 'geostd8_bin', '')) -_charsets.add(Charset(94, 'latin1', 'latin1_spanish_ci', '')) -_charsets.add(Charset(95, 'cp932', 'cp932_japanese_ci', 'Yes')) -_charsets.add(Charset(96, 'cp932', 'cp932_bin', '')) -_charsets.add(Charset(97, 'eucjpms', 'eucjpms_japanese_ci', 'Yes')) -_charsets.add(Charset(98, 'eucjpms', 'eucjpms_bin', '')) -_charsets.add(Charset(99, 'cp1250', 'cp1250_polish_ci', '')) -_charsets.add(Charset(101, 'utf16', 'utf16_unicode_ci', '')) -_charsets.add(Charset(102, 'utf16', 'utf16_icelandic_ci', '')) -_charsets.add(Charset(103, 'utf16', 'utf16_latvian_ci', '')) -_charsets.add(Charset(104, 'utf16', 'utf16_romanian_ci', '')) -_charsets.add(Charset(105, 'utf16', 'utf16_slovenian_ci', '')) -_charsets.add(Charset(106, 'utf16', 'utf16_polish_ci', '')) -_charsets.add(Charset(107, 'utf16', 'utf16_estonian_ci', '')) -_charsets.add(Charset(108, 'utf16', 'utf16_spanish_ci', '')) -_charsets.add(Charset(109, 'utf16', 'utf16_swedish_ci', '')) -_charsets.add(Charset(110, 'utf16', 'utf16_turkish_ci', '')) -_charsets.add(Charset(111, 'utf16', 'utf16_czech_ci', '')) -_charsets.add(Charset(112, 'utf16', 'utf16_danish_ci', '')) -_charsets.add(Charset(113, 'utf16', 'utf16_lithuanian_ci', '')) -_charsets.add(Charset(114, 'utf16', 'utf16_slovak_ci', '')) -_charsets.add(Charset(115, 'utf16', 'utf16_spanish2_ci', '')) -_charsets.add(Charset(116, 'utf16', 'utf16_roman_ci', '')) -_charsets.add(Charset(117, 'utf16', 'utf16_persian_ci', '')) -_charsets.add(Charset(118, 'utf16', 'utf16_esperanto_ci', '')) -_charsets.add(Charset(119, 'utf16', 'utf16_hungarian_ci', '')) -_charsets.add(Charset(120, 'utf16', 'utf16_sinhala_ci', '')) -_charsets.add(Charset(128, 'ucs2', 'ucs2_unicode_ci', '')) -_charsets.add(Charset(129, 'ucs2', 'ucs2_icelandic_ci', '')) -_charsets.add(Charset(130, 'ucs2', 'ucs2_latvian_ci', '')) -_charsets.add(Charset(131, 'ucs2', 'ucs2_romanian_ci', '')) -_charsets.add(Charset(132, 'ucs2', 'ucs2_slovenian_ci', '')) -_charsets.add(Charset(133, 'ucs2', 'ucs2_polish_ci', '')) -_charsets.add(Charset(134, 'ucs2', 'ucs2_estonian_ci', '')) -_charsets.add(Charset(135, 'ucs2', 'ucs2_spanish_ci', '')) -_charsets.add(Charset(136, 'ucs2', 'ucs2_swedish_ci', '')) -_charsets.add(Charset(137, 'ucs2', 'ucs2_turkish_ci', '')) -_charsets.add(Charset(138, 'ucs2', 'ucs2_czech_ci', '')) -_charsets.add(Charset(139, 'ucs2', 'ucs2_danish_ci', '')) -_charsets.add(Charset(140, 'ucs2', 'ucs2_lithuanian_ci', '')) -_charsets.add(Charset(141, 'ucs2', 'ucs2_slovak_ci', '')) -_charsets.add(Charset(142, 'ucs2', 'ucs2_spanish2_ci', '')) -_charsets.add(Charset(143, 'ucs2', 'ucs2_roman_ci', '')) -_charsets.add(Charset(144, 'ucs2', 'ucs2_persian_ci', '')) -_charsets.add(Charset(145, 'ucs2', 'ucs2_esperanto_ci', '')) -_charsets.add(Charset(146, 'ucs2', 'ucs2_hungarian_ci', '')) -_charsets.add(Charset(147, 'ucs2', 'ucs2_sinhala_ci', '')) -_charsets.add(Charset(159, 'ucs2', 'ucs2_general_mysql500_ci', '')) -_charsets.add(Charset(160, 'utf32', 'utf32_unicode_ci', '')) -_charsets.add(Charset(161, 'utf32', 'utf32_icelandic_ci', '')) -_charsets.add(Charset(162, 'utf32', 'utf32_latvian_ci', '')) -_charsets.add(Charset(163, 'utf32', 'utf32_romanian_ci', '')) -_charsets.add(Charset(164, 'utf32', 'utf32_slovenian_ci', '')) -_charsets.add(Charset(165, 'utf32', 'utf32_polish_ci', '')) -_charsets.add(Charset(166, 'utf32', 'utf32_estonian_ci', '')) -_charsets.add(Charset(167, 'utf32', 'utf32_spanish_ci', '')) -_charsets.add(Charset(168, 'utf32', 'utf32_swedish_ci', '')) -_charsets.add(Charset(169, 'utf32', 'utf32_turkish_ci', '')) -_charsets.add(Charset(170, 'utf32', 'utf32_czech_ci', '')) -_charsets.add(Charset(171, 'utf32', 'utf32_danish_ci', '')) -_charsets.add(Charset(172, 'utf32', 'utf32_lithuanian_ci', '')) -_charsets.add(Charset(173, 'utf32', 'utf32_slovak_ci', '')) -_charsets.add(Charset(174, 'utf32', 'utf32_spanish2_ci', '')) -_charsets.add(Charset(175, 'utf32', 'utf32_roman_ci', '')) -_charsets.add(Charset(176, 'utf32', 'utf32_persian_ci', '')) -_charsets.add(Charset(177, 'utf32', 'utf32_esperanto_ci', '')) -_charsets.add(Charset(178, 'utf32', 'utf32_hungarian_ci', '')) -_charsets.add(Charset(179, 'utf32', 'utf32_sinhala_ci', '')) -_charsets.add(Charset(192, 'utf8', 'utf8_unicode_ci', '')) -_charsets.add(Charset(193, 'utf8', 'utf8_icelandic_ci', '')) -_charsets.add(Charset(194, 'utf8', 'utf8_latvian_ci', '')) -_charsets.add(Charset(195, 'utf8', 'utf8_romanian_ci', '')) -_charsets.add(Charset(196, 'utf8', 'utf8_slovenian_ci', '')) -_charsets.add(Charset(197, 'utf8', 'utf8_polish_ci', '')) -_charsets.add(Charset(198, 'utf8', 'utf8_estonian_ci', '')) -_charsets.add(Charset(199, 'utf8', 'utf8_spanish_ci', '')) -_charsets.add(Charset(200, 'utf8', 'utf8_swedish_ci', '')) -_charsets.add(Charset(201, 'utf8', 'utf8_turkish_ci', '')) -_charsets.add(Charset(202, 'utf8', 'utf8_czech_ci', '')) -_charsets.add(Charset(203, 'utf8', 'utf8_danish_ci', '')) -_charsets.add(Charset(204, 'utf8', 'utf8_lithuanian_ci', '')) -_charsets.add(Charset(205, 'utf8', 'utf8_slovak_ci', '')) -_charsets.add(Charset(206, 'utf8', 'utf8_spanish2_ci', '')) -_charsets.add(Charset(207, 'utf8', 'utf8_roman_ci', '')) -_charsets.add(Charset(208, 'utf8', 'utf8_persian_ci', '')) -_charsets.add(Charset(209, 'utf8', 'utf8_esperanto_ci', '')) -_charsets.add(Charset(210, 'utf8', 'utf8_hungarian_ci', '')) -_charsets.add(Charset(211, 'utf8', 'utf8_sinhala_ci', '')) -_charsets.add(Charset(223, 'utf8', 'utf8_general_mysql500_ci', '')) -_charsets.add(Charset(224, 'utf8mb4', 'utf8mb4_unicode_ci', '')) -_charsets.add(Charset(225, 'utf8mb4', 'utf8mb4_icelandic_ci', '')) -_charsets.add(Charset(226, 'utf8mb4', 'utf8mb4_latvian_ci', '')) -_charsets.add(Charset(227, 'utf8mb4', 'utf8mb4_romanian_ci', '')) -_charsets.add(Charset(228, 'utf8mb4', 'utf8mb4_slovenian_ci', '')) -_charsets.add(Charset(229, 'utf8mb4', 'utf8mb4_polish_ci', '')) -_charsets.add(Charset(230, 'utf8mb4', 'utf8mb4_estonian_ci', '')) -_charsets.add(Charset(231, 'utf8mb4', 'utf8mb4_spanish_ci', '')) -_charsets.add(Charset(232, 'utf8mb4', 'utf8mb4_swedish_ci', '')) -_charsets.add(Charset(233, 'utf8mb4', 'utf8mb4_turkish_ci', '')) -_charsets.add(Charset(234, 'utf8mb4', 'utf8mb4_czech_ci', '')) -_charsets.add(Charset(235, 'utf8mb4', 'utf8mb4_danish_ci', '')) -_charsets.add(Charset(236, 'utf8mb4', 'utf8mb4_lithuanian_ci', '')) -_charsets.add(Charset(237, 'utf8mb4', 'utf8mb4_slovak_ci', '')) -_charsets.add(Charset(238, 'utf8mb4', 'utf8mb4_spanish2_ci', '')) -_charsets.add(Charset(239, 'utf8mb4', 'utf8mb4_roman_ci', '')) -_charsets.add(Charset(240, 'utf8mb4', 'utf8mb4_persian_ci', '')) -_charsets.add(Charset(241, 'utf8mb4', 'utf8mb4_esperanto_ci', '')) -_charsets.add(Charset(242, 'utf8mb4', 'utf8mb4_hungarian_ci', '')) -_charsets.add(Charset(243, 'utf8mb4', 'utf8mb4_sinhala_ci', '')) -_charsets.add(Charset(244, 'utf8mb4', 'utf8mb4_german2_ci', '')) -_charsets.add(Charset(245, 'utf8mb4', 'utf8mb4_croatian_ci', '')) -_charsets.add(Charset(246, 'utf8mb4', 'utf8mb4_unicode_520_ci', '')) -_charsets.add(Charset(247, 'utf8mb4', 'utf8mb4_vietnamese_ci', '')) - - -charset_by_name = _charsets.by_name -charset_by_id = _charsets.by_id - -def charset_to_encoding(name): - """Convert MySQL's charset name to Python's codec name""" - if name == 'utf8mb4': - return 'utf8' - return name +_charsets.add(Charset(1, "big5", "big5_chinese_ci", True)) +_charsets.add(Charset(2, "latin2", "latin2_czech_cs")) +_charsets.add(Charset(3, "dec8", "dec8_swedish_ci", True)) +_charsets.add(Charset(4, "cp850", "cp850_general_ci", True)) +_charsets.add(Charset(5, "latin1", "latin1_german1_ci")) +_charsets.add(Charset(6, "hp8", "hp8_english_ci", True)) +_charsets.add(Charset(7, "koi8r", "koi8r_general_ci", True)) +_charsets.add(Charset(8, "latin1", "latin1_swedish_ci", True)) +_charsets.add(Charset(9, "latin2", "latin2_general_ci", True)) +_charsets.add(Charset(10, "swe7", "swe7_swedish_ci", True)) +_charsets.add(Charset(11, "ascii", "ascii_general_ci", True)) +_charsets.add(Charset(12, "ujis", "ujis_japanese_ci", True)) +_charsets.add(Charset(13, "sjis", "sjis_japanese_ci", True)) +_charsets.add(Charset(14, "cp1251", "cp1251_bulgarian_ci")) +_charsets.add(Charset(15, "latin1", "latin1_danish_ci")) +_charsets.add(Charset(16, "hebrew", "hebrew_general_ci", True)) +_charsets.add(Charset(18, "tis620", "tis620_thai_ci", True)) +_charsets.add(Charset(19, "euckr", "euckr_korean_ci", True)) +_charsets.add(Charset(20, "latin7", "latin7_estonian_cs")) +_charsets.add(Charset(21, "latin2", "latin2_hungarian_ci")) +_charsets.add(Charset(22, "koi8u", "koi8u_general_ci", True)) +_charsets.add(Charset(23, "cp1251", "cp1251_ukrainian_ci")) +_charsets.add(Charset(24, "gb2312", "gb2312_chinese_ci", True)) +_charsets.add(Charset(25, "greek", "greek_general_ci", True)) +_charsets.add(Charset(26, "cp1250", "cp1250_general_ci", True)) +_charsets.add(Charset(27, "latin2", "latin2_croatian_ci")) +_charsets.add(Charset(28, "gbk", "gbk_chinese_ci", True)) +_charsets.add(Charset(29, "cp1257", "cp1257_lithuanian_ci")) +_charsets.add(Charset(30, "latin5", "latin5_turkish_ci", True)) +_charsets.add(Charset(31, "latin1", "latin1_german2_ci")) +_charsets.add(Charset(32, "armscii8", "armscii8_general_ci", True)) +_charsets.add(Charset(33, "utf8mb3", "utf8mb3_general_ci", True)) +_charsets.add(Charset(34, "cp1250", "cp1250_czech_cs")) +_charsets.add(Charset(36, "cp866", "cp866_general_ci", True)) +_charsets.add(Charset(37, "keybcs2", "keybcs2_general_ci", True)) +_charsets.add(Charset(38, "macce", "macce_general_ci", True)) +_charsets.add(Charset(39, "macroman", "macroman_general_ci", True)) +_charsets.add(Charset(40, "cp852", "cp852_general_ci", True)) +_charsets.add(Charset(41, "latin7", "latin7_general_ci", True)) +_charsets.add(Charset(42, "latin7", "latin7_general_cs")) +_charsets.add(Charset(43, "macce", "macce_bin")) +_charsets.add(Charset(44, "cp1250", "cp1250_croatian_ci")) +_charsets.add(Charset(45, "utf8mb4", "utf8mb4_general_ci", True)) +_charsets.add(Charset(46, "utf8mb4", "utf8mb4_bin")) +_charsets.add(Charset(47, "latin1", "latin1_bin")) +_charsets.add(Charset(48, "latin1", "latin1_general_ci")) +_charsets.add(Charset(49, "latin1", "latin1_general_cs")) +_charsets.add(Charset(50, "cp1251", "cp1251_bin")) +_charsets.add(Charset(51, "cp1251", "cp1251_general_ci", True)) +_charsets.add(Charset(52, "cp1251", "cp1251_general_cs")) +_charsets.add(Charset(53, "macroman", "macroman_bin")) +_charsets.add(Charset(57, "cp1256", "cp1256_general_ci", True)) +_charsets.add(Charset(58, "cp1257", "cp1257_bin")) +_charsets.add(Charset(59, "cp1257", "cp1257_general_ci", True)) +_charsets.add(Charset(63, "binary", "binary", True)) +_charsets.add(Charset(64, "armscii8", "armscii8_bin")) +_charsets.add(Charset(65, "ascii", "ascii_bin")) +_charsets.add(Charset(66, "cp1250", "cp1250_bin")) +_charsets.add(Charset(67, "cp1256", "cp1256_bin")) +_charsets.add(Charset(68, "cp866", "cp866_bin")) +_charsets.add(Charset(69, "dec8", "dec8_bin")) +_charsets.add(Charset(70, "greek", "greek_bin")) +_charsets.add(Charset(71, "hebrew", "hebrew_bin")) +_charsets.add(Charset(72, "hp8", "hp8_bin")) +_charsets.add(Charset(73, "keybcs2", "keybcs2_bin")) +_charsets.add(Charset(74, "koi8r", "koi8r_bin")) +_charsets.add(Charset(75, "koi8u", "koi8u_bin")) +_charsets.add(Charset(76, "utf8mb3", "utf8mb3_tolower_ci")) +_charsets.add(Charset(77, "latin2", "latin2_bin")) +_charsets.add(Charset(78, "latin5", "latin5_bin")) +_charsets.add(Charset(79, "latin7", "latin7_bin")) +_charsets.add(Charset(80, "cp850", "cp850_bin")) +_charsets.add(Charset(81, "cp852", "cp852_bin")) +_charsets.add(Charset(82, "swe7", "swe7_bin")) +_charsets.add(Charset(83, "utf8mb3", "utf8mb3_bin")) +_charsets.add(Charset(84, "big5", "big5_bin")) +_charsets.add(Charset(85, "euckr", "euckr_bin")) +_charsets.add(Charset(86, "gb2312", "gb2312_bin")) +_charsets.add(Charset(87, "gbk", "gbk_bin")) +_charsets.add(Charset(88, "sjis", "sjis_bin")) +_charsets.add(Charset(89, "tis620", "tis620_bin")) +_charsets.add(Charset(91, "ujis", "ujis_bin")) +_charsets.add(Charset(92, "geostd8", "geostd8_general_ci", True)) +_charsets.add(Charset(93, "geostd8", "geostd8_bin")) +_charsets.add(Charset(94, "latin1", "latin1_spanish_ci")) +_charsets.add(Charset(95, "cp932", "cp932_japanese_ci", True)) +_charsets.add(Charset(96, "cp932", "cp932_bin")) +_charsets.add(Charset(97, "eucjpms", "eucjpms_japanese_ci", True)) +_charsets.add(Charset(98, "eucjpms", "eucjpms_bin")) +_charsets.add(Charset(99, "cp1250", "cp1250_polish_ci")) +_charsets.add(Charset(192, "utf8mb3", "utf8mb3_unicode_ci")) +_charsets.add(Charset(193, "utf8mb3", "utf8mb3_icelandic_ci")) +_charsets.add(Charset(194, "utf8mb3", "utf8mb3_latvian_ci")) +_charsets.add(Charset(195, "utf8mb3", "utf8mb3_romanian_ci")) +_charsets.add(Charset(196, "utf8mb3", "utf8mb3_slovenian_ci")) +_charsets.add(Charset(197, "utf8mb3", "utf8mb3_polish_ci")) +_charsets.add(Charset(198, "utf8mb3", "utf8mb3_estonian_ci")) +_charsets.add(Charset(199, "utf8mb3", "utf8mb3_spanish_ci")) +_charsets.add(Charset(200, "utf8mb3", "utf8mb3_swedish_ci")) +_charsets.add(Charset(201, "utf8mb3", "utf8mb3_turkish_ci")) +_charsets.add(Charset(202, "utf8mb3", "utf8mb3_czech_ci")) +_charsets.add(Charset(203, "utf8mb3", "utf8mb3_danish_ci")) +_charsets.add(Charset(204, "utf8mb3", "utf8mb3_lithuanian_ci")) +_charsets.add(Charset(205, "utf8mb3", "utf8mb3_slovak_ci")) +_charsets.add(Charset(206, "utf8mb3", "utf8mb3_spanish2_ci")) +_charsets.add(Charset(207, "utf8mb3", "utf8mb3_roman_ci")) +_charsets.add(Charset(208, "utf8mb3", "utf8mb3_persian_ci")) +_charsets.add(Charset(209, "utf8mb3", "utf8mb3_esperanto_ci")) +_charsets.add(Charset(210, "utf8mb3", "utf8mb3_hungarian_ci")) +_charsets.add(Charset(211, "utf8mb3", "utf8mb3_sinhala_ci")) +_charsets.add(Charset(212, "utf8mb3", "utf8mb3_german2_ci")) +_charsets.add(Charset(213, "utf8mb3", "utf8mb3_croatian_ci")) +_charsets.add(Charset(214, "utf8mb3", "utf8mb3_unicode_520_ci")) +_charsets.add(Charset(215, "utf8mb3", "utf8mb3_vietnamese_ci")) +_charsets.add(Charset(223, "utf8mb3", "utf8mb3_general_mysql500_ci")) +_charsets.add(Charset(224, "utf8mb4", "utf8mb4_unicode_ci")) +_charsets.add(Charset(225, "utf8mb4", "utf8mb4_icelandic_ci")) +_charsets.add(Charset(226, "utf8mb4", "utf8mb4_latvian_ci")) +_charsets.add(Charset(227, "utf8mb4", "utf8mb4_romanian_ci")) +_charsets.add(Charset(228, "utf8mb4", "utf8mb4_slovenian_ci")) +_charsets.add(Charset(229, "utf8mb4", "utf8mb4_polish_ci")) +_charsets.add(Charset(230, "utf8mb4", "utf8mb4_estonian_ci")) +_charsets.add(Charset(231, "utf8mb4", "utf8mb4_spanish_ci")) +_charsets.add(Charset(232, "utf8mb4", "utf8mb4_swedish_ci")) +_charsets.add(Charset(233, "utf8mb4", "utf8mb4_turkish_ci")) +_charsets.add(Charset(234, "utf8mb4", "utf8mb4_czech_ci")) +_charsets.add(Charset(235, "utf8mb4", "utf8mb4_danish_ci")) +_charsets.add(Charset(236, "utf8mb4", "utf8mb4_lithuanian_ci")) +_charsets.add(Charset(237, "utf8mb4", "utf8mb4_slovak_ci")) +_charsets.add(Charset(238, "utf8mb4", "utf8mb4_spanish2_ci")) +_charsets.add(Charset(239, "utf8mb4", "utf8mb4_roman_ci")) +_charsets.add(Charset(240, "utf8mb4", "utf8mb4_persian_ci")) +_charsets.add(Charset(241, "utf8mb4", "utf8mb4_esperanto_ci")) +_charsets.add(Charset(242, "utf8mb4", "utf8mb4_hungarian_ci")) +_charsets.add(Charset(243, "utf8mb4", "utf8mb4_sinhala_ci")) +_charsets.add(Charset(244, "utf8mb4", "utf8mb4_german2_ci")) +_charsets.add(Charset(245, "utf8mb4", "utf8mb4_croatian_ci")) +_charsets.add(Charset(246, "utf8mb4", "utf8mb4_unicode_520_ci")) +_charsets.add(Charset(247, "utf8mb4", "utf8mb4_vietnamese_ci")) +_charsets.add(Charset(248, "gb18030", "gb18030_chinese_ci", True)) +_charsets.add(Charset(249, "gb18030", "gb18030_bin")) +_charsets.add(Charset(250, "gb18030", "gb18030_unicode_520_ci")) +_charsets.add(Charset(255, "utf8mb4", "utf8mb4_0900_ai_ci")) diff --git a/pymysql/connections.py b/pymysql/connections.py index 1e580d21d..3a04ddd68 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -1,12 +1,8 @@ # Python implementation of the MySQL client-server protocol # http://dev.mysql.com/doc/internals/en/client-server-protocol.html # Error codes: -# http://dev.mysql.com/doc/refman/5.5/en/error-messages-client.html -from __future__ import print_function -from ._compat import PY2, range_type, text_type, str_type, JYTHON, IRONPYTHON - +# https://dev.mysql.com/doc/refman/5.5/en/error-handling.html import errno -import io import os import socket import struct @@ -17,19 +13,23 @@ from . import _auth from .charset import charset_by_name, charset_by_id -from .constants import CLIENT, COMMAND, CR, FIELD_TYPE, SERVER_STATUS +from .constants import CLIENT, COMMAND, CR, ER, FIELD_TYPE, SERVER_STATUS from . import converters from .cursors import Cursor from .optionfile import Parser from .protocol import ( - dump_packet, MysqlPacket, FieldDescriptorPacket, OKPacketWrapper, - EOFPacketWrapper, LoadLocalPacketWrapper + dump_packet, + MysqlPacket, + FieldDescriptorPacket, + OKPacketWrapper, + EOFPacketWrapper, + LoadLocalPacketWrapper, ) -from .util import byte2int, int2byte from . import err, VERSION_STRING try: import ssl + SSL_ENABLED = True except ImportError: ssl = None @@ -37,6 +37,7 @@ try: import getpass + DEFAULT_USER = getpass.getuser() del getpass except (ImportError, KeyError): @@ -45,36 +46,6 @@ DEBUG = False -_py_version = sys.version_info[:2] - -if PY2: - pass -elif _py_version < (3, 6): - # See http://bugs.python.org/issue24870 - _surrogateescape_table = [chr(i) if i < 0x80 else chr(i + 0xdc00) for i in range(256)] - - def _fast_surrogateescape(s): - return s.decode('latin1').translate(_surrogateescape_table) -else: - def _fast_surrogateescape(s): - return s.decode('ascii', 'surrogateescape') - -# socket.makefile() in Python 2 is not usable because very inefficient and -# bad behavior about timeout. -# XXX: ._socketio doesn't work under IronPython. -if PY2 and not IRONPYTHON: - # read method of file-like returned by sock.makefile() is very slow. - # So we copy io-based one from Python 3. - from ._socketio import SocketIO - - def _makefile(sock, mode): - return io.BufferedReader(SocketIO(sock, mode)) -else: - # socket.makefile in Python 3 is nice. - def _makefile(sock, mode): - return sock.makefile(mode) - - TEXT_TYPES = { FIELD_TYPE.BIT, FIELD_TYPE.BLOB, @@ -88,32 +59,36 @@ def _makefile(sock, mode): } -DEFAULT_CHARSET = 'utf8mb4' # TODO: change to utf8mb4 +DEFAULT_CHARSET = "utf8mb4" -MAX_PACKET_LEN = 2**24-1 +MAX_PACKET_LEN = 2**24 - 1 -def pack_int24(n): - return struct.pack('`_ in the specification. """ _sock = None - _auth_plugin_name = '' + _auth_plugin_name = "" _closed = False _secure = False - def __init__(self, host=None, user=None, password="", - database=None, port=0, unix_socket=None, - charset='', sql_mode=None, - read_default_file=None, conv=None, use_unicode=None, - client_flag=0, cursorclass=Cursor, init_command=None, - connect_timeout=10, ssl=None, read_default_group=None, - compress=None, named_pipe=None, - autocommit=False, db=None, passwd=None, local_infile=False, - max_allowed_packet=16*1024*1024, defer_connect=False, - auth_plugin_map=None, read_timeout=None, write_timeout=None, - bind_address=None, binary_prefix=False, program_name=None, - server_public_key=None): - if use_unicode is None and sys.version_info[0] > 2: - use_unicode = True - + def __init__( + self, + *, + user=None, # The first four arguments is based on DB-API 2.0 recommendation. + password="", + host=None, + database=None, + unix_socket=None, + port=0, + charset="", + collation=None, + sql_mode=None, + read_default_file=None, + conv=None, + use_unicode=True, + client_flag=0, + cursorclass=Cursor, + init_command=None, + connect_timeout=10, + read_default_group=None, + autocommit=False, + local_infile=False, + max_allowed_packet=16 * 1024 * 1024, + defer_connect=False, + auth_plugin_map=None, + read_timeout=None, + write_timeout=None, + bind_address=None, + binary_prefix=False, + program_name=None, + server_public_key=None, + ssl=None, + ssl_ca=None, + ssl_cert=None, + ssl_disabled=None, + ssl_key=None, + ssl_key_password=None, + ssl_verify_cert=None, + ssl_verify_identity=None, + compress=None, # not supported + named_pipe=None, # not supported + passwd=None, # deprecated + db=None, # deprecated + ): if db is not None and database is None: + # We will raise warning in 2022 or later. + # See https://github.com/PyMySQL/PyMySQL/issues/939 + # warnings.warn("'db' is deprecated, use 'database'", DeprecationWarning, 3) database = db if passwd is not None and not password: + # We will raise warning in 2022 or later. + # See https://github.com/PyMySQL/PyMySQL/issues/939 + # warnings.warn( + # "'passwd' is deprecated, use 'password'", DeprecationWarning, 3 + # ) password = passwd if compress or named_pipe: - raise NotImplementedError("compress and named_pipe arguments are not supported") + raise NotImplementedError( + "compress and named_pipe arguments are not supported" + ) self._local_infile = bool(local_infile) if self._local_infile: @@ -240,25 +263,42 @@ def _config(key, arg): if not ssl: ssl = {} if isinstance(ssl, dict): - for key in ["ca", "capath", "cert", "key", "cipher"]: + for key in ["ca", "capath", "cert", "key", "password", "cipher"]: value = _config("ssl-" + key, ssl.get(key)) if value: ssl[key] = value self.ssl = False - if ssl: - if not SSL_ENABLED: - raise NotImplementedError("ssl module not found") - self.ssl = True - client_flag |= CLIENT.SSL - self.ctx = self._create_ssl_ctx(ssl) + if not ssl_disabled: + if ssl_ca or ssl_cert or ssl_key or ssl_verify_cert or ssl_verify_identity: + ssl = { + "ca": ssl_ca, + "check_hostname": bool(ssl_verify_identity), + "verify_mode": ssl_verify_cert + if ssl_verify_cert is not None + else False, + } + if ssl_cert is not None: + ssl["cert"] = ssl_cert + if ssl_key is not None: + ssl["key"] = ssl_key + if ssl_key_password is not None: + ssl["password"] = ssl_key_password + if ssl: + if not SSL_ENABLED: + raise NotImplementedError("ssl module not found") + self.ssl = True + client_flag |= CLIENT.SSL + self.ctx = self._create_ssl_ctx(ssl) self.host = host or "localhost" self.port = port or 3306 + if type(self.port) is not int: + raise ValueError("port should be of type int") self.user = user or DEFAULT_USER self.password = password or b"" - if isinstance(self.password, text_type): - self.password = self.password.encode('latin1') + if isinstance(self.password, str): + self.password = self.password.encode("latin1") self.db = database self.unix_socket = unix_socket self.bind_address = bind_address @@ -266,20 +306,15 @@ def _config(key, arg): raise ValueError("connect_timeout should be >0 and <=31536000") self.connect_timeout = connect_timeout or None if read_timeout is not None and read_timeout <= 0: - raise ValueError("read_timeout should be >= 0") + raise ValueError("read_timeout should be > 0") self._read_timeout = read_timeout if write_timeout is not None and write_timeout <= 0: - raise ValueError("write_timeout should be >= 0") + raise ValueError("write_timeout should be > 0") self._write_timeout = write_timeout - if charset: - self.charset = charset - self.use_unicode = True - else: - self.charset = DEFAULT_CHARSET - self.use_unicode = False - if use_unicode is not None: - self.use_unicode = use_unicode + self.charset = charset or DEFAULT_CHARSET + self.collation = collation + self.use_unicode = use_unicode self.encoding = charset_by_name(self.charset).encoding @@ -295,15 +330,15 @@ def _config(key, arg): self._affected_rows = 0 self.host_info = "Not connected" - #: specified autocommit mode. None means use server default. + # specified autocommit mode. None means use server default. self.autocommit_mode = autocommit if conv is None: conv = converters.conversions # Need for MySQLdb compatibility. - self.encoders = dict([(k, v) for (k, v) in conv.items() if type(k) is not int]) - self.decoders = dict([(k, v) for (k, v) in conv.items() if type(k) is int]) + self.encoders = {k: v for (k, v) in conv.items() if type(k) is not int} + self.decoders = {k: v for (k, v) in conv.items() if type(k) is int} self.sql_mode = sql_mode self.init_command = init_command self.max_allowed_packet = max_allowed_packet @@ -312,33 +347,56 @@ def _config(key, arg): self.server_public_key = server_public_key self._connect_attrs = { - '_client_name': 'pymysql', - '_pid': str(os.getpid()), - '_client_version': VERSION_STRING, + "_client_name": "pymysql", + "_client_version": VERSION_STRING, + "_pid": str(os.getpid()), } + if program_name: self._connect_attrs["program_name"] = program_name - elif sys.argv: - self._connect_attrs["program_name"] = sys.argv[0] if defer_connect: self._sock = None else: self.connect() + def __enter__(self): + return self + + def __exit__(self, *exc_info): + del exc_info + self.close() + def _create_ssl_ctx(self, sslp): if isinstance(sslp, ssl.SSLContext): return sslp - ca = sslp.get('ca') - capath = sslp.get('capath') + ca = sslp.get("ca") + capath = sslp.get("capath") hasnoca = ca is None and capath is None ctx = ssl.create_default_context(cafile=ca, capath=capath) - ctx.check_hostname = not hasnoca and sslp.get('check_hostname', True) - ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED - if 'cert' in sslp: - ctx.load_cert_chain(sslp['cert'], keyfile=sslp.get('key')) - if 'cipher' in sslp: - ctx.set_ciphers(sslp['cipher']) + ctx.check_hostname = not hasnoca and sslp.get("check_hostname", True) + verify_mode_value = sslp.get("verify_mode") + if verify_mode_value is None: + ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED + elif isinstance(verify_mode_value, bool): + ctx.verify_mode = ssl.CERT_REQUIRED if verify_mode_value else ssl.CERT_NONE + else: + if isinstance(verify_mode_value, str): + verify_mode_value = verify_mode_value.lower() + if verify_mode_value in ("none", "0", "false", "no"): + ctx.verify_mode = ssl.CERT_NONE + elif verify_mode_value == "optional": + ctx.verify_mode = ssl.CERT_OPTIONAL + elif verify_mode_value in ("required", "1", "true", "yes"): + ctx.verify_mode = ssl.CERT_REQUIRED + else: + ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED + if "cert" in sslp: + ctx.load_cert_chain( + sslp["cert"], keyfile=sslp.get("key"), password=sslp.get("password") + ) + if "cipher" in sslp: + ctx.set_ciphers(sslp["cipher"]) ctx.options |= ssl.OP_NO_SSLv2 ctx.options |= ssl.OP_NO_SSLv3 return ctx @@ -357,7 +415,7 @@ def close(self): self._closed = True if self._sock is None: return - send_data = struct.pack('= 5: + if int(self.server_version.split(".", 1)[0]) >= 5: self.client_flag |= CLIENT.MULTI_RESULTS if self.user is None: raise ValueError("Did not specify a username") charset_id = charset_by_name(self.charset).id - if isinstance(self.user, text_type): + if isinstance(self.user, str): self.user = self.user.encode(self.encoding) - data_init = struct.pack('=5.0) - data += authresp + b'\0' + data += authresp + b"\0" if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB: - if isinstance(self.db, text_type): + if isinstance(self.db, str): self.db = self.db.encode(self.encoding) - data += self.db + b'\0' + data += self.db + b"\0" if self.server_capabilities & CLIENT.PLUGIN_AUTH: - data += (plugin_name or b'') + b'\0' + data += (plugin_name or b"") + b"\0" if self.server_capabilities & CLIENT.CONNECT_ATTRS: - connect_attrs = b'' + connect_attrs = b"" for k, v in self._connect_attrs.items(): - k = k.encode('utf8') - connect_attrs += struct.pack('B', len(k)) + k - v = v.encode('utf8') - connect_attrs += struct.pack('B', len(v)) + v - data += struct.pack('B', len(connect_attrs)) + connect_attrs + k = k.encode("utf-8") + connect_attrs += _lenenc_int(len(k)) + k + v = v.encode("utf-8") + connect_attrs += _lenenc_int(len(v)) + v + data += _lenenc_int(len(connect_attrs)) + connect_attrs self.write_packet(data) auth_packet = self._read_packet() @@ -854,17 +959,18 @@ def _request_authentication(self): # if authentication method isn't accepted the first byte # will have the octet 254 if auth_packet.is_auth_switch_request(): - if DEBUG: print("received auth switch") + if DEBUG: + print("received auth switch") # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest - auth_packet.read_uint8() # 0xfe packet identifier + auth_packet.read_uint8() # 0xfe packet identifier plugin_name = auth_packet.read_string() - if self.server_capabilities & CLIENT.PLUGIN_AUTH and plugin_name is not None: + if ( + self.server_capabilities & CLIENT.PLUGIN_AUTH + and plugin_name is not None + ): auth_packet = self._process_auth(plugin_name, auth_packet) else: - # send legacy handshake - data = _auth.scramble_old_password(self.password, self.salt) + b'\0' - self.write_packet(data) - auth_packet = self._read_packet() + raise err.OperationalError("received unknown auth switch request") elif auth_packet.is_extra_auth_data(): if DEBUG: print("received extra data") @@ -874,9 +980,12 @@ def _request_authentication(self): elif self._auth_plugin_name == "sha256_password": auth_packet = _auth.sha256_password_auth(self, auth_packet) else: - raise err.OperationalError("Received extra packet for auth method %r", self._auth_plugin_name) + raise err.OperationalError( + "Received extra packet for auth method %r", self._auth_plugin_name + ) - if DEBUG: print("Succeed to auth") + if DEBUG: + print("Succeed to auth") def _process_auth(self, plugin_name, auth_packet): handler = self._get_auth_plugin_handler(plugin_name) @@ -884,20 +993,28 @@ def _process_auth(self, plugin_name, auth_packet): try: return handler.authenticate(auth_packet) except AttributeError: - if plugin_name != b'dialog': - raise err.OperationalError(2059, "Authentication plugin '%s'" - " not loaded: - %r missing authenticate method" % (plugin_name, type(handler))) + if plugin_name != b"dialog": + raise err.OperationalError( + CR.CR_AUTH_PLUGIN_CANNOT_LOAD, + f"Authentication plugin '{plugin_name}'" + f" not loaded: - {type(handler)!r} missing authenticate method", + ) if plugin_name == b"caching_sha2_password": return _auth.caching_sha2_password_auth(self, auth_packet) elif plugin_name == b"sha256_password": return _auth.sha256_password_auth(self, auth_packet) elif plugin_name == b"mysql_native_password": data = _auth.scramble_native_password(self.password, auth_packet.read_all()) + elif plugin_name == b"client_ed25519": + data = _auth.ed25519_password(self.password, auth_packet.read_all()) elif plugin_name == b"mysql_old_password": - data = _auth.scramble_old_password(self.password, auth_packet.read_all()) + b'\0' + data = ( + _auth.scramble_old_password(self.password, auth_packet.read_all()) + + b"\0" + ) elif plugin_name == b"mysql_clear_password": # https://dev.mysql.com/doc/internals/en/clear-text-authentication.html - data = self.password + b'\0' + data = self.password + b"\0" elif plugin_name == b"dialog": pkt = auth_packet while True: @@ -907,27 +1024,39 @@ def _process_auth(self, plugin_name, auth_packet): prompt = pkt.read_all() if prompt == b"Password: ": - self.write_packet(self.password + b'\0') + self.write_packet(self.password + b"\0") elif handler: - resp = 'no response - TypeError within plugin.prompt method' + resp = "no response - TypeError within plugin.prompt method" try: resp = handler.prompt(echo, prompt) - self.write_packet(resp + b'\0') + self.write_packet(resp + b"\0") except AttributeError: - raise err.OperationalError(2059, "Authentication plugin '%s'" \ - " not loaded: - %r missing prompt method" % (plugin_name, handler)) + raise err.OperationalError( + CR.CR_AUTH_PLUGIN_CANNOT_LOAD, + f"Authentication plugin '{plugin_name}'" + f" not loaded: - {handler!r} missing prompt method", + ) except TypeError: - raise err.OperationalError(2061, "Authentication plugin '%s'" \ - " %r didn't respond with string. Returned '%r' to prompt %r" % (plugin_name, handler, resp, prompt)) + raise err.OperationalError( + CR.CR_AUTH_PLUGIN_ERR, + f"Authentication plugin '{plugin_name}'" + f" {handler!r} didn't respond with string. Returned '{resp!r}' to prompt {prompt!r}", + ) else: - raise err.OperationalError(2059, "Authentication plugin '%s' (%r) not configured" % (plugin_name, handler)) + raise err.OperationalError( + CR.CR_AUTH_PLUGIN_CANNOT_LOAD, + f"Authentication plugin '{plugin_name}' not configured", + ) pkt = self._read_packet() pkt.check_error() if pkt.is_ok_packet() or last: break return pkt else: - raise err.OperationalError(2059, "Authentication plugin '%s' not configured" % plugin_name) + raise err.OperationalError( + CR.CR_AUTH_PLUGIN_CANNOT_LOAD, + "Authentication plugin '%s' not configured" % plugin_name, + ) self.write_packet(data) pkt = self._read_packet() @@ -937,13 +1066,16 @@ def _process_auth(self, plugin_name, auth_packet): def _get_auth_plugin_handler(self, plugin_name): plugin_class = self._auth_plugin_map.get(plugin_name) if not plugin_class and isinstance(plugin_name, bytes): - plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii')) + plugin_class = self._auth_plugin_map.get(plugin_name.decode("ascii")) if plugin_class: try: handler = plugin_class(self) except TypeError: - raise err.OperationalError(2059, "Authentication plugin '%s'" - " not loaded: - %r cannot be constructed with connection object" % (plugin_name, plugin_class)) + raise err.OperationalError( + CR.CR_AUTH_PLUGIN_CANNOT_LOAD, + f"Authentication plugin '{plugin_name}'" + f" not loaded: - {plugin_class!r} cannot be constructed with connection object", + ) else: handler = None return handler @@ -966,24 +1098,24 @@ def _get_server_information(self): packet = self._read_packet() data = packet.get_all_data() - self.protocol_version = byte2int(data[i:i+1]) + self.protocol_version = data[i] i += 1 - server_end = data.find(b'\0', i) - self.server_version = data[i:server_end].decode('latin1') + server_end = data.find(b"\0", i) + self.server_version = data[i:server_end].decode("latin1") i = server_end + 1 - self.server_thread_id = struct.unpack('= i + 6: - lang, stat, cap_h, salt_len = struct.unpack('= i + salt_len: # salt_len includes auth_plugin_data_part_1 and filler - self.salt += data[i:i+salt_len] + self.salt += data[i : i + salt_len] i += salt_len - i+=1 + i += 1 # AUTH PLUGIN NAME may appear here. if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i: # Due to Bug#59453 the auth-plugin-name is missing the terminating @@ -1017,12 +1151,12 @@ def _get_server_information(self): # ref: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake # didn't use version checks as mariadb is corrected and reports # earlier than those two. - server_end = data.find(b'\0', i) - if server_end < 0: # pragma: no cover - very specific upstream bug + server_end = data.find(b"\0", i) + if server_end < 0: # pragma: no cover - very specific upstream bug # not found \0 and last field so take it all - self._auth_plugin_name = data[i:].decode('utf-8') + self._auth_plugin_name = data[i:].decode("utf-8") else: - self._auth_plugin_name = data[i:server_end].decode('utf-8') + self._auth_plugin_name = data[i:server_end].decode("utf-8") def get_server_info(self): return self.server_version @@ -1039,8 +1173,7 @@ def get_server_info(self): NotSupportedError = err.NotSupportedError -class MySQLResult(object): - +class MySQLResult: def __init__(self, connection): """ :type connection: Connection @@ -1111,7 +1244,8 @@ def _read_ok_packet(self, first_packet): def _read_load_local_packet(self, first_packet): if not self.connection._local_infile: raise RuntimeError( - "**WARN**: Received LOAD_LOCAL packet but local_infile option is false.") + "**WARN**: Received LOAD_LOCAL packet but local_infile option is false." + ) load_packet = LoadLocalPacketWrapper(first_packet) sender = LoadLocalFile(load_packet.filename, self.connection) try: @@ -1121,17 +1255,23 @@ def _read_load_local_packet(self, first_packet): raise ok_packet = self.connection._read_packet() - if not ok_packet.is_ok_packet(): # pragma: no cover - upstream induced protocol error - raise err.OperationalError(2014, "Commands Out of Sync") + if ( + not ok_packet.is_ok_packet() + ): # pragma: no cover - upstream induced protocol error + raise err.OperationalError( + CR.CR_COMMANDS_OUT_OF_SYNC, + "Commands Out of Sync", + ) self._read_ok_packet(ok_packet) def _check_packet_is_eof(self, packet): if not packet.is_eof_packet(): return False - #TODO: Support CLIENT.DEPRECATE_EOF + # TODO: Support CLIENT.DEPRECATE_EOF # 1) Add DEPRECATE_EOF to CAPABILITIES # 2) Mask CAPABILITIES with server_capabilities - # 3) if server_capabilities & CLIENT.DEPRECATE_EOF: use OKPacketWrapper instead of EOFPacketWrapper + # 3) if server_capabilities & CLIENT.DEPRECATE_EOF: + # use OKPacketWrapper instead of EOFPacketWrapper wp = EOFPacketWrapper(packet) self.warning_count = wp.warning_count self.has_next = wp.has_next @@ -1165,7 +1305,20 @@ def _finish_unbuffered_query(self): # in fact, no way to stop MySQL from sending all the data after # executing a query, so we just spin, and wait for an EOF packet. while self.unbuffered_active: - packet = self.connection._read_packet() + try: + packet = self.connection._read_packet() + except err.OperationalError as e: + if e.args[0] in ( + ER.QUERY_TIMEOUT, + ER.STATEMENT_TIMEOUT, + ): + # if the query timed out we can simply ignore this error + self.unbuffered_active = False + self.connection = None + return + + raise + if self._check_packet_is_eof(packet): self.unbuffered_active = False self.connection = None # release reference to kill cyclic reference. @@ -1195,7 +1348,8 @@ def _read_row_from_packet(self, packet): if data is not None: if encoding is not None: data = data.decode(encoding) - if DEBUG: print("DEBUG: DATA = ", data) + if DEBUG: + print("DEBUG: DATA = ", data) if converter is not None: data = converter(data) row.append(data) @@ -1209,7 +1363,7 @@ def _get_descriptions(self): conn_encoding = self.connection.encoding description = [] - for i in range_type(self.field_count): + for i in range(self.field_count): field = self.connection._read_packet(FieldDescriptorPacket) self.fields.append(field) description.append(field.description()) @@ -1230,21 +1384,22 @@ def _get_descriptions(self): encoding = conn_encoding else: # Integers, Dates and Times, and other basic data is encoded in ascii - encoding = 'ascii' + encoding = "ascii" else: encoding = None converter = self.connection.decoders.get(field_type) if converter is converters.through: converter = None - if DEBUG: print("DEBUG: field={}, converter={}".format(field, converter)) + if DEBUG: + print(f"DEBUG: field={field}, converter={converter}") self.converters.append((encoding, converter)) eof_packet = self.connection._read_packet() - assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF' + assert eof_packet.is_eof_packet(), "Protocol error, expecting EOF" self.description = tuple(description) -class LoadLocalFile(object): +class LoadLocalFile: def __init__(self, filename, connection): self.filename = filename self.connection = connection @@ -1252,19 +1407,25 @@ def __init__(self, filename, connection): def send_data(self): """Send data packets from the local file to the server""" if not self.connection._sock: - raise err.InterfaceError("(0, '')") - conn = self.connection + raise err.InterfaceError(0, "") + conn: Connection = self.connection try: - with open(self.filename, 'rb') as open_file: - packet_size = min(conn.max_allowed_packet, 16*1024) # 16KB is efficient enough + with open(self.filename, "rb") as open_file: + packet_size = min( + conn.max_allowed_packet, 16 * 1024 + ) # 16KB is efficient enough while True: chunk = open_file.read(packet_size) if not chunk: break conn.write_packet(chunk) - except IOError: - raise err.OperationalError(1017, "Can't find file '{0}'".format(self.filename)) + except OSError: + raise err.OperationalError( + ER.FILE_NOT_FOUND, + f"Can't find file '{self.filename}'", + ) finally: - # send the empty packet to signify we are done sending data - conn.write_packet(b'') + if not conn._closed: + # send the empty packet to signify we are done sending data + conn.write_packet(b"") diff --git a/pymysql/constants/CLIENT.py b/pymysql/constants/CLIENT.py index b42f1523c..34fe57a5d 100644 --- a/pymysql/constants/CLIENT.py +++ b/pymysql/constants/CLIENT.py @@ -21,9 +21,16 @@ CONNECT_ATTRS = 1 << 20 PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21 CAPABILITIES = ( - LONG_PASSWORD | LONG_FLAG | PROTOCOL_41 | TRANSACTIONS - | SECURE_CONNECTION | MULTI_RESULTS - | PLUGIN_AUTH | PLUGIN_AUTH_LENENC_CLIENT_DATA | CONNECT_ATTRS) + LONG_PASSWORD + | LONG_FLAG + | PROTOCOL_41 + | TRANSACTIONS + | SECURE_CONNECTION + | MULTI_RESULTS + | PLUGIN_AUTH + | PLUGIN_AUTH_LENENC_CLIENT_DATA + | CONNECT_ATTRS +) # Not done yet HANDLE_EXPIRED_PASSWORDS = 1 << 22 diff --git a/pymysql/constants/COMMAND.py b/pymysql/constants/COMMAND.py index 1da275533..2d98850b8 100644 --- a/pymysql/constants/COMMAND.py +++ b/pymysql/constants/COMMAND.py @@ -1,4 +1,3 @@ - COM_SLEEP = 0x00 COM_QUIT = 0x01 COM_INIT_DB = 0x02 @@ -9,12 +8,12 @@ COM_REFRESH = 0x07 COM_SHUTDOWN = 0x08 COM_STATISTICS = 0x09 -COM_PROCESS_INFO = 0x0a -COM_CONNECT = 0x0b -COM_PROCESS_KILL = 0x0c -COM_DEBUG = 0x0d -COM_PING = 0x0e -COM_TIME = 0x0f +COM_PROCESS_INFO = 0x0A +COM_CONNECT = 0x0B +COM_PROCESS_KILL = 0x0C +COM_DEBUG = 0x0D +COM_PING = 0x0E +COM_TIME = 0x0F COM_DELAYED_INSERT = 0x10 COM_CHANGE_USER = 0x11 COM_BINLOG_DUMP = 0x12 @@ -25,9 +24,9 @@ COM_STMT_EXECUTE = 0x17 COM_STMT_SEND_LONG_DATA = 0x18 COM_STMT_CLOSE = 0x19 -COM_STMT_RESET = 0x1a -COM_SET_OPTION = 0x1b -COM_STMT_FETCH = 0x1c -COM_DAEMON = 0x1d -COM_BINLOG_DUMP_GTID = 0x1e -COM_END = 0x1f +COM_STMT_RESET = 0x1A +COM_SET_OPTION = 0x1B +COM_STMT_FETCH = 0x1C +COM_DAEMON = 0x1D +COM_BINLOG_DUMP_GTID = 0x1E +COM_END = 0x1F diff --git a/pymysql/constants/CR.py b/pymysql/constants/CR.py index 48ca956ec..deae977e5 100644 --- a/pymysql/constants/CR.py +++ b/pymysql/constants/CR.py @@ -1,68 +1,79 @@ # flake8: noqa # errmsg.h -CR_ERROR_FIRST = 2000 -CR_UNKNOWN_ERROR = 2000 -CR_SOCKET_CREATE_ERROR = 2001 -CR_CONNECTION_ERROR = 2002 -CR_CONN_HOST_ERROR = 2003 -CR_IPSOCK_ERROR = 2004 -CR_UNKNOWN_HOST = 2005 -CR_SERVER_GONE_ERROR = 2006 -CR_VERSION_ERROR = 2007 -CR_OUT_OF_MEMORY = 2008 -CR_WRONG_HOST_INFO = 2009 +CR_ERROR_FIRST = 2000 +CR_UNKNOWN_ERROR = 2000 +CR_SOCKET_CREATE_ERROR = 2001 +CR_CONNECTION_ERROR = 2002 +CR_CONN_HOST_ERROR = 2003 +CR_IPSOCK_ERROR = 2004 +CR_UNKNOWN_HOST = 2005 +CR_SERVER_GONE_ERROR = 2006 +CR_VERSION_ERROR = 2007 +CR_OUT_OF_MEMORY = 2008 +CR_WRONG_HOST_INFO = 2009 CR_LOCALHOST_CONNECTION = 2010 -CR_TCP_CONNECTION = 2011 +CR_TCP_CONNECTION = 2011 CR_SERVER_HANDSHAKE_ERR = 2012 -CR_SERVER_LOST = 2013 +CR_SERVER_LOST = 2013 CR_COMMANDS_OUT_OF_SYNC = 2014 CR_NAMEDPIPE_CONNECTION = 2015 -CR_NAMEDPIPEWAIT_ERROR = 2016 -CR_NAMEDPIPEOPEN_ERROR = 2017 +CR_NAMEDPIPEWAIT_ERROR = 2016 +CR_NAMEDPIPEOPEN_ERROR = 2017 CR_NAMEDPIPESETSTATE_ERROR = 2018 -CR_CANT_READ_CHARSET = 2019 +CR_CANT_READ_CHARSET = 2019 CR_NET_PACKET_TOO_LARGE = 2020 -CR_EMBEDDED_CONNECTION = 2021 -CR_PROBE_SLAVE_STATUS = 2022 -CR_PROBE_SLAVE_HOSTS = 2023 -CR_PROBE_SLAVE_CONNECT = 2024 +CR_EMBEDDED_CONNECTION = 2021 +CR_PROBE_SLAVE_STATUS = 2022 +CR_PROBE_SLAVE_HOSTS = 2023 +CR_PROBE_SLAVE_CONNECT = 2024 CR_PROBE_MASTER_CONNECT = 2025 CR_SSL_CONNECTION_ERROR = 2026 -CR_MALFORMED_PACKET = 2027 -CR_WRONG_LICENSE = 2028 +CR_MALFORMED_PACKET = 2027 +CR_WRONG_LICENSE = 2028 -CR_NULL_POINTER = 2029 -CR_NO_PREPARE_STMT = 2030 -CR_PARAMS_NOT_BOUND = 2031 -CR_DATA_TRUNCATED = 2032 +CR_NULL_POINTER = 2029 +CR_NO_PREPARE_STMT = 2030 +CR_PARAMS_NOT_BOUND = 2031 +CR_DATA_TRUNCATED = 2032 CR_NO_PARAMETERS_EXISTS = 2033 CR_INVALID_PARAMETER_NO = 2034 -CR_INVALID_BUFFER_USE = 2035 +CR_INVALID_BUFFER_USE = 2035 CR_UNSUPPORTED_PARAM_TYPE = 2036 -CR_SHARED_MEMORY_CONNECTION = 2037 -CR_SHARED_MEMORY_CONNECT_REQUEST_ERROR = 2038 -CR_SHARED_MEMORY_CONNECT_ANSWER_ERROR = 2039 +CR_SHARED_MEMORY_CONNECTION = 2037 +CR_SHARED_MEMORY_CONNECT_REQUEST_ERROR = 2038 +CR_SHARED_MEMORY_CONNECT_ANSWER_ERROR = 2039 CR_SHARED_MEMORY_CONNECT_FILE_MAP_ERROR = 2040 -CR_SHARED_MEMORY_CONNECT_MAP_ERROR = 2041 -CR_SHARED_MEMORY_FILE_MAP_ERROR = 2042 -CR_SHARED_MEMORY_MAP_ERROR = 2043 -CR_SHARED_MEMORY_EVENT_ERROR = 2044 +CR_SHARED_MEMORY_CONNECT_MAP_ERROR = 2041 +CR_SHARED_MEMORY_FILE_MAP_ERROR = 2042 +CR_SHARED_MEMORY_MAP_ERROR = 2043 +CR_SHARED_MEMORY_EVENT_ERROR = 2044 CR_SHARED_MEMORY_CONNECT_ABANDONED_ERROR = 2045 -CR_SHARED_MEMORY_CONNECT_SET_ERROR = 2046 -CR_CONN_UNKNOW_PROTOCOL = 2047 -CR_INVALID_CONN_HANDLE = 2048 -CR_SECURE_AUTH = 2049 -CR_FETCH_CANCELED = 2050 -CR_NO_DATA = 2051 -CR_NO_STMT_METADATA = 2052 -CR_NO_RESULT_SET = 2053 -CR_NOT_IMPLEMENTED = 2054 -CR_SERVER_LOST_EXTENDED = 2055 -CR_STMT_CLOSED = 2056 -CR_NEW_STMT_METADATA = 2057 -CR_ALREADY_CONNECTED = 2058 -CR_AUTH_PLUGIN_CANNOT_LOAD = 2059 -CR_DUPLICATE_CONNECTION_ATTR = 2060 -CR_AUTH_PLUGIN_ERR = 2061 -CR_ERROR_LAST = 2061 +CR_SHARED_MEMORY_CONNECT_SET_ERROR = 2046 +CR_CONN_UNKNOW_PROTOCOL = 2047 +CR_INVALID_CONN_HANDLE = 2048 +CR_SECURE_AUTH = 2049 +CR_FETCH_CANCELED = 2050 +CR_NO_DATA = 2051 +CR_NO_STMT_METADATA = 2052 +CR_NO_RESULT_SET = 2053 +CR_NOT_IMPLEMENTED = 2054 +CR_SERVER_LOST_EXTENDED = 2055 +CR_STMT_CLOSED = 2056 +CR_NEW_STMT_METADATA = 2057 +CR_ALREADY_CONNECTED = 2058 +CR_AUTH_PLUGIN_CANNOT_LOAD = 2059 +CR_DUPLICATE_CONNECTION_ATTR = 2060 +CR_AUTH_PLUGIN_ERR = 2061 +CR_INSECURE_API_ERR = 2062 +CR_FILE_NAME_TOO_LONG = 2063 +CR_SSL_FIPS_MODE_ERR = 2064 +CR_DEPRECATED_COMPRESSION_NOT_SUPPORTED = 2065 +CR_COMPRESSION_WRONGLY_CONFIGURED = 2066 +CR_KERBEROS_USER_NOT_FOUND = 2067 +CR_LOAD_DATA_LOCAL_INFILE_REJECTED = 2068 +CR_LOAD_DATA_LOCAL_INFILE_REALPATH_FAIL = 2069 +CR_DNS_SRV_LOOKUP_FAILED = 2070 +CR_MANDATORY_TRACKER_NOT_FOUND = 2071 +CR_INVALID_FACTOR_NO = 2072 +CR_ERROR_LAST = 2072 diff --git a/pymysql/constants/ER.py b/pymysql/constants/ER.py index 79b88afbe..98729d12d 100644 --- a/pymysql/constants/ER.py +++ b/pymysql/constants/ER.py @@ -1,4 +1,3 @@ - ERROR_FIRST = 1000 HASHCHK = 1000 NISAMCHK = 1001 @@ -471,5 +470,8 @@ WRONG_STRING_LENGTH = 1468 ERROR_LAST = 1468 +# MariaDB only +STATEMENT_TIMEOUT = 1969 +QUERY_TIMEOUT = 3024 # https://github.com/PyMySQL/PyMySQL/issues/607 CONSTRAINT_FAILED = 4025 diff --git a/pymysql/constants/FIELD_TYPE.py b/pymysql/constants/FIELD_TYPE.py index 51bd5143b..b8b448660 100644 --- a/pymysql/constants/FIELD_TYPE.py +++ b/pymysql/constants/FIELD_TYPE.py @@ -1,5 +1,3 @@ - - DECIMAL = 0 TINY = 1 SHORT = 2 diff --git a/pymysql/constants/SERVER_STATUS.py b/pymysql/constants/SERVER_STATUS.py index 6f5d56630..8f8d77688 100644 --- a/pymysql/constants/SERVER_STATUS.py +++ b/pymysql/constants/SERVER_STATUS.py @@ -1,4 +1,3 @@ - SERVER_STATUS_IN_TRANS = 1 SERVER_STATUS_AUTOCOMMIT = 2 SERVER_MORE_RESULTS_EXISTS = 8 diff --git a/pymysql/converters.py b/pymysql/converters.py index bf1db9d77..dbf97ca75 100644 --- a/pymysql/converters.py +++ b/pymysql/converters.py @@ -1,12 +1,10 @@ -from ._compat import PY2, text_type, long_type, JYTHON, IRONPYTHON, unichr - import datetime from decimal import Decimal import re import time -from .constants import FIELD_TYPE, FLAG -from .charset import charset_by_id, charset_to_encoding +from .err import ProgrammingError +from .constants import FIELD_TYPE def escape_item(val, charset, mapping=None): @@ -17,7 +15,7 @@ def escape_item(val, charset, mapping=None): # Fallback to default when no encoder found if not encoder: try: - encoder = mapping[text_type] + encoder = mapping[str] except KeyError: raise TypeError("no default type converter defined") @@ -27,12 +25,10 @@ def escape_item(val, charset, mapping=None): val = encoder(val, mapping) return val + def escape_dict(val, charset, mapping=None): - n = {} - for k, v in val.items(): - quoted = escape_item(v, charset, mapping) - n[k] = quoted - return n + raise TypeError("dict can not be used as parameter") + def escape_sequence(val, charset, mapping=None): n = [] @@ -41,87 +37,63 @@ def escape_sequence(val, charset, mapping=None): n.append(quoted) return "(" + ",".join(n) + ")" + def escape_set(val, charset, mapping=None): - return ','.join([escape_item(x, charset, mapping) for x in val]) + return ",".join([escape_item(x, charset, mapping) for x in val]) + def escape_bool(value, mapping=None): return str(int(value)) -def escape_object(value, mapping=None): - return str(value) def escape_int(value, mapping=None): return str(value) -def escape_float(value, mapping=None): - return ('%.15g' % value) - -_escape_table = [unichr(x) for x in range(128)] -_escape_table[0] = u'\\0' -_escape_table[ord('\\')] = u'\\\\' -_escape_table[ord('\n')] = u'\\n' -_escape_table[ord('\r')] = u'\\r' -_escape_table[ord('\032')] = u'\\Z' -_escape_table[ord('"')] = u'\\"' -_escape_table[ord("'")] = u"\\'" - -def _escape_unicode(value, mapping=None): - """escapes *value* without adding quote. - Value should be unicode - """ - return value.translate(_escape_table) +def escape_float(value, mapping=None): + s = repr(value) + if s in ("inf", "-inf", "nan"): + raise ProgrammingError("%s can not be used with MySQL" % s) + if "e" not in s: + s += "e0" + return s -if PY2: - def escape_string(value, mapping=None): - """escape_string escapes *value* but not surround it with quotes. - Value should be bytes or unicode. - """ - if isinstance(value, unicode): - return _escape_unicode(value) - assert isinstance(value, (bytes, bytearray)) - value = value.replace('\\', '\\\\') - value = value.replace('\0', '\\0') - value = value.replace('\n', '\\n') - value = value.replace('\r', '\\r') - value = value.replace('\032', '\\Z') - value = value.replace("'", "\\'") - value = value.replace('"', '\\"') - return value +_escape_table = [chr(x) for x in range(128)] +_escape_table[0] = "\\0" +_escape_table[ord("\\")] = "\\\\" +_escape_table[ord("\n")] = "\\n" +_escape_table[ord("\r")] = "\\r" +_escape_table[ord("\032")] = "\\Z" +_escape_table[ord('"')] = '\\"' +_escape_table[ord("'")] = "\\'" - def escape_bytes_prefixed(value, mapping=None): - assert isinstance(value, (bytes, bytearray)) - return b"_binary'%s'" % escape_string(value) - def escape_bytes(value, mapping=None): - assert isinstance(value, (bytes, bytearray)) - return b"'%s'" % escape_string(value) +def escape_string(value, mapping=None): + """escapes *value* without adding quote. -else: - escape_string = _escape_unicode + Value should be unicode + """ + return value.translate(_escape_table) - # On Python ~3.5, str.decode('ascii', 'surrogateescape') is slow. - # (fixed in Python 3.6, http://bugs.python.org/issue24870) - # Workaround is str.decode('latin1') then translate 0x80-0xff into 0udc80-0udcff. - # We can escape special chars and surrogateescape at once. - _escape_bytes_table = _escape_table + [chr(i) for i in range(0xdc80, 0xdd00)] - def escape_bytes_prefixed(value, mapping=None): - return "_binary'%s'" % value.decode('latin1').translate(_escape_bytes_table) +def escape_bytes_prefixed(value, mapping=None): + return "_binary'%s'" % value.decode("ascii", "surrogateescape").translate( + _escape_table + ) - def escape_bytes(value, mapping=None): - return "'%s'" % value.decode('latin1').translate(_escape_bytes_table) +def escape_bytes(value, mapping=None): + return "'%s'" % value.decode("ascii", "surrogateescape").translate(_escape_table) -def escape_unicode(value, mapping=None): - return u"'%s'" % _escape_unicode(value) def escape_str(value, mapping=None): return "'%s'" % escape_string(str(value), mapping) + def escape_None(value, mapping=None): - return 'NULL' + return "NULL" + def escape_timedelta(obj, mapping=None): seconds = int(obj.seconds) % 60 @@ -133,6 +105,7 @@ def escape_timedelta(obj, mapping=None): fmt = "'{0:02d}:{1:02d}:{2:02d}'" return fmt.format(hours, minutes, seconds, obj.microseconds) + def escape_time(obj, mapping=None): if obj.microsecond: fmt = "'{0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'" @@ -140,48 +113,61 @@ def escape_time(obj, mapping=None): fmt = "'{0.hour:02}:{0.minute:02}:{0.second:02}'" return fmt.format(obj) + def escape_datetime(obj, mapping=None): if obj.microsecond: - fmt = "'{0.year:04}-{0.month:02}-{0.day:02} {0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'" + fmt = ( + "'{0.year:04}-{0.month:02}-{0.day:02}" + + " {0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'" + ) else: fmt = "'{0.year:04}-{0.month:02}-{0.day:02} {0.hour:02}:{0.minute:02}:{0.second:02}'" return fmt.format(obj) + def escape_date(obj, mapping=None): fmt = "'{0.year:04}-{0.month:02}-{0.day:02}'" return fmt.format(obj) + def escape_struct_time(obj, mapping=None): return escape_datetime(datetime.datetime(*obj[:6])) + +def Decimal2Literal(o, d): + return format(o, "f") + + def _convert_second_fraction(s): if not s: return 0 # Pad zeros to ensure the fraction length in microseconds - s = s.ljust(6, '0') + s = s.ljust(6, "0") return int(s[:6]) -DATETIME_RE = re.compile(r"(\d{1,4})-(\d{1,2})-(\d{1,2})[T ](\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?") + +DATETIME_RE = re.compile( + r"(\d{1,4})-(\d{1,2})-(\d{1,2})[T ](\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?" +) def convert_datetime(obj): """Returns a DATETIME or TIMESTAMP column value as a datetime object: - >>> datetime_or_None('2007-02-25 23:06:20') + >>> convert_datetime('2007-02-25 23:06:20') datetime.datetime(2007, 2, 25, 23, 6, 20) - >>> datetime_or_None('2007-02-25T23:06:20') + >>> convert_datetime('2007-02-25T23:06:20') datetime.datetime(2007, 2, 25, 23, 6, 20) - Illegal values are returned as None: - - >>> datetime_or_None('2007-02-31T23:06:20') is None - True - >>> datetime_or_None('0000-00-00 00:00:00') is None - True + Illegal values are returned as str: + >>> convert_datetime('2007-02-31T23:06:20') + '2007-02-31T23:06:20' + >>> convert_datetime('0000-00-00 00:00:00') + '0000-00-00 00:00:00' """ - if not PY2 and isinstance(obj, (bytes, bytearray)): - obj = obj.decode('ascii') + if isinstance(obj, (bytes, bytearray)): + obj = obj.decode("ascii") m = DATETIME_RE.match(obj) if not m: @@ -190,32 +176,33 @@ def convert_datetime(obj): try: groups = list(m.groups()) groups[-1] = _convert_second_fraction(groups[-1]) - return datetime.datetime(*[ int(x) for x in groups ]) + return datetime.datetime(*[int(x) for x in groups]) except ValueError: return convert_date(obj) + TIMEDELTA_RE = re.compile(r"(-)?(\d{1,3}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?") def convert_timedelta(obj): """Returns a TIME column as a timedelta object: - >>> timedelta_or_None('25:06:17') - datetime.timedelta(1, 3977) - >>> timedelta_or_None('-25:06:17') - datetime.timedelta(-2, 83177) + >>> convert_timedelta('25:06:17') + datetime.timedelta(days=1, seconds=3977) + >>> convert_timedelta('-25:06:17') + datetime.timedelta(days=-2, seconds=82423) - Illegal values are returned as None: + Illegal values are returned as string: - >>> timedelta_or_None('random crap') is None - True + >>> convert_timedelta('random crap') + 'random crap' Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but can accept values as (+|-)DD HH:MM:SS. The latter format will not be parsed correctly by this function. """ - if not PY2 and isinstance(obj, (bytes, bytearray)): - obj = obj.decode('ascii') + if isinstance(obj, (bytes, bytearray)): + obj = obj.decode("ascii") m = TIMEDELTA_RE.match(obj) if not m: @@ -227,31 +214,35 @@ def convert_timedelta(obj): negate = -1 if groups[0] else 1 hours, minutes, seconds, microseconds = groups[1:] - tdelta = datetime.timedelta( - hours = int(hours), - minutes = int(minutes), - seconds = int(seconds), - microseconds = int(microseconds) - ) * negate + tdelta = ( + datetime.timedelta( + hours=int(hours), + minutes=int(minutes), + seconds=int(seconds), + microseconds=int(microseconds), + ) + * negate + ) return tdelta except ValueError: return obj + TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?") def convert_time(obj): """Returns a TIME column as a time object: - >>> time_or_None('15:06:17') + >>> convert_time('15:06:17') datetime.time(15, 6, 17) - Illegal values are returned as None: + Illegal values are returned as str: - >>> time_or_None('-25:06:17') is None - True - >>> time_or_None('random crap') is None - True + >>> convert_time('-25:06:17') + '-25:06:17' + >>> convert_time('random crap') + 'random crap' Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but can accept values as (+|-)DD HH:MM:SS. The latter format will not @@ -262,8 +253,8 @@ def convert_time(obj): to be treated as time-of-day and not a time offset, then you can use set this function as the converter for FIELD_TYPE.TIME. """ - if not PY2 and isinstance(obj, (bytes, bytearray)): - obj = obj.decode('ascii') + if isinstance(obj, (bytes, bytearray)): + obj = obj.decode("ascii") m = TIME_RE.match(obj) if not m: @@ -273,8 +264,12 @@ def convert_time(obj): groups = list(m.groups()) groups[-1] = _convert_second_fraction(groups[-1]) hours, minutes, seconds, microseconds = groups - return datetime.time(hour=int(hours), minute=int(minutes), - second=int(seconds), microsecond=int(microseconds)) + return datetime.time( + hour=int(hours), + minute=int(minutes), + second=int(seconds), + microsecond=int(microseconds), + ) except ValueError: return obj @@ -282,70 +277,29 @@ def convert_time(obj): def convert_date(obj): """Returns a DATE column as a date object: - >>> date_or_None('2007-02-26') + >>> convert_date('2007-02-26') datetime.date(2007, 2, 26) - Illegal values are returned as None: - - >>> date_or_None('2007-02-31') is None - True - >>> date_or_None('0000-00-00') is None - True + Illegal values are returned as str: + >>> convert_date('2007-02-31') + '2007-02-31' + >>> convert_date('0000-00-00') + '0000-00-00' """ - if not PY2 and isinstance(obj, (bytes, bytearray)): - obj = obj.decode('ascii') + if isinstance(obj, (bytes, bytearray)): + obj = obj.decode("ascii") try: - return datetime.date(*[ int(x) for x in obj.split('-', 2) ]) + return datetime.date(*[int(x) for x in obj.split("-", 2)]) except ValueError: return obj -def convert_mysql_timestamp(timestamp): - """Convert a MySQL TIMESTAMP to a Timestamp object. - - MySQL >= 4.1 returns TIMESTAMP in the same format as DATETIME: - - >>> mysql_timestamp_converter('2007-02-25 22:32:17') - datetime.datetime(2007, 2, 25, 22, 32, 17) - - MySQL < 4.1 uses a big string of numbers: - - >>> mysql_timestamp_converter('20070225223217') - datetime.datetime(2007, 2, 25, 22, 32, 17) - - Illegal values are returned as None: - - >>> mysql_timestamp_converter('2007-02-31 22:32:17') is None - True - >>> mysql_timestamp_converter('00000000000000') is None - True - - """ - if not PY2 and isinstance(timestamp, (bytes, bytearray)): - timestamp = timestamp.decode('ascii') - if timestamp[4] == '-': - return convert_datetime(timestamp) - timestamp += "0"*(14-len(timestamp)) # padding - year, month, day, hour, minute, second = \ - int(timestamp[:4]), int(timestamp[4:6]), int(timestamp[6:8]), \ - int(timestamp[8:10]), int(timestamp[10:12]), int(timestamp[12:14]) - try: - return datetime.datetime(year, month, day, hour, minute, second) - except ValueError: - return timestamp - -def convert_set(s): - if isinstance(s, (bytes, bytearray)): - return set(s.split(b",")) - return set(s.split(",")) - - def through(x): return x -#def convert_bit(b): +# def convert_bit(b): # b = "\x00" * (8 - len(b)) + b # pad w/ zeroes # return struct.unpack(">Q", b)[0] # @@ -354,28 +308,12 @@ def through(x): convert_bit = through -def convert_characters(connection, field, data): - field_charset = charset_by_id(field.charsetnr).name - encoding = charset_to_encoding(field_charset) - if field.flags & FLAG.SET: - return convert_set(data.decode(encoding)) - if field.flags & FLAG.BINARY: - return data - - if connection.use_unicode: - data = data.decode(encoding) - elif connection.charset != field_charset: - data = data.decode(encoding) - data = data.encode(connection.encoding) - return data - encoders = { bool: escape_bool, int: escape_int, - long_type: escape_int, float: escape_float, str: escape_str, - text_type: escape_unicode, + bytes: escape_bytes, tuple: escape_sequence, list: escape_sequence, set: escape_sequence, @@ -387,11 +325,9 @@ def convert_characters(connection, field, data): datetime.timedelta: escape_timedelta, datetime.time: escape_time, time.struct_time: escape_struct_time, - Decimal: escape_object, + Decimal: Decimal2Literal, } -if not PY2 or JYTHON or IRONPYTHON: - encoders[bytes] = escape_bytes decoders = { FIELD_TYPE.BIT: convert_bit, @@ -403,11 +339,10 @@ def convert_characters(connection, field, data): FIELD_TYPE.LONGLONG: int, FIELD_TYPE.INT24: int, FIELD_TYPE.YEAR: int, - FIELD_TYPE.TIMESTAMP: convert_mysql_timestamp, + FIELD_TYPE.TIMESTAMP: convert_datetime, FIELD_TYPE.DATETIME: convert_datetime, FIELD_TYPE.TIME: convert_timedelta, FIELD_TYPE.DATE: convert_date, - FIELD_TYPE.SET: convert_set, FIELD_TYPE.BLOB: through, FIELD_TYPE.TINY_BLOB: through, FIELD_TYPE.MEDIUM_BLOB: through, @@ -424,3 +359,5 @@ def convert_characters(connection, field, data): conversions = encoders.copy() conversions.update(decoders) Thing2Literal = escape_str + +# Run doctests with `pytest --doctest-modules pymysql/converters.py` diff --git a/pymysql/cursors.py b/pymysql/cursors.py index cc169987b..8be05ca23 100644 --- a/pymysql/cursors.py +++ b/pymysql/cursors.py @@ -1,26 +1,22 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, absolute_import -from functools import partial import re import warnings - -from ._compat import range_type, text_type, PY2 from . import err #: Regular expression for :meth:`Cursor.executemany`. -#: executemany only suports simple bulk insert. +#: executemany only supports simple bulk insert. #: You can use it to load large dataset. RE_INSERT_VALUES = re.compile( - r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)" + - r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" + - r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z", - re.IGNORECASE | re.DOTALL) + r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)" + + r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" + + r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z", + re.IGNORECASE | re.DOTALL, +) -class Cursor(object): +class Cursor: """ - This is the object you use to interact with the database. + This is the object used to interact with the database. Do not create an instance of a Cursor yourself. Call connections.Connection.cursor(). @@ -35,10 +31,9 @@ class Cursor(object): #: Default value of max_allowed_packet is 1048576. max_stmt_length = 1024000 - _defer_warnings = False - def __init__(self, connection): self.connection = connection + self.warning_count = 0 self.description = None self.rownumber = 0 self.rowcount = -1 @@ -46,7 +41,6 @@ def __init__(self, connection): self._executed = None self._result = None self._rows = None - self._warnings_handled = False def close(self): """ @@ -87,12 +81,9 @@ def setoutputsizes(self, *args): """Does nothing, required by DB API.""" def _nextset(self, unbuffered=False): - """Get the next query set""" + """Get the next query set.""" conn = self._get_db() current_result = self._result - # for unbuffered queries warnings are only available once whole result has been read - if unbuffered: - self._show_warnings() if current_result is None or current_result is not conn._result: return None if not current_result.has_next: @@ -106,42 +97,33 @@ def _nextset(self, unbuffered=False): def nextset(self): return self._nextset(False) - def _ensure_bytes(self, x, encoding=None): - if isinstance(x, text_type): - x = x.encode(encoding) - elif isinstance(x, (tuple, list)): - x = type(x)(self._ensure_bytes(v, encoding=encoding) for v in x) - return x - def _escape_args(self, args, conn): - ensure_bytes = partial(self._ensure_bytes, encoding=conn.encoding) - if isinstance(args, (tuple, list)): - if PY2: - args = tuple(map(ensure_bytes, args)) return tuple(conn.literal(arg) for arg in args) elif isinstance(args, dict): - if PY2: - args = dict((ensure_bytes(key), ensure_bytes(val)) for - (key, val) in args.items()) - return dict((key, conn.literal(val)) for (key, val) in args.items()) + return {key: conn.literal(val) for (key, val) in args.items()} else: # If it's not a dictionary let's try escaping it anyways. # Worst case it will throw a Value error - if PY2: - args = ensure_bytes(args) return conn.escape(args) def mogrify(self, query, args=None): """ - Returns the exact string that is sent to the database by calling the + Returns the exact string that would be sent to the database by calling the execute() method. + :param query: Query to mogrify. + :type query: str + + :param args: Parameters used with query. (optional) + :type args: tuple, list or dict + + :return: The query with argument binding applied. + :rtype: str + This method follows the extension to the DB API 2.0 followed by Psycopg. """ conn = self._get_db() - if PY2: # Use bytes on Python 2 always - query = self._ensure_bytes(query, encoding=conn.encoding) if args is not None: query = query % self._escape_args(args, conn) @@ -149,14 +131,15 @@ def mogrify(self, query, args=None): return query def execute(self, query, args=None): - """Execute a query + """Execute a query. - :param str query: Query to execute. + :param query: Query to execute. + :type query: str - :param args: parameters used with query. (optional) + :param args: Parameters used with query. (optional) :type args: tuple, list or dict - :return: Number of affected rows + :return: Number of affected rows. :rtype: int If args is a list or tuple, %s can be used as a placeholder in the query. @@ -172,12 +155,16 @@ def execute(self, query, args=None): return result def executemany(self, query, args): - # type: (str, list) -> int - """Run several data against one query + """Run several data against one query. + + :param query: Query to execute. + :type query: str + + :param args: Sequence of sequences or mappings. It is used as parameter. + :type args: tuple or list - :param query: query to execute on server - :param args: Sequence of sequences or mappings. It is used as parameter. :return: Number of rows affected, if any. + :rtype: int or None This method improves performance on multiple-row INSERT and REPLACE. Otherwise it is equivalent to looping over args with @@ -190,57 +177,58 @@ def executemany(self, query, args): if m: q_prefix = m.group(1) % () q_values = m.group(2).rstrip() - q_postfix = m.group(3) or '' - assert q_values[0] == '(' and q_values[-1] == ')' - return self._do_execute_many(q_prefix, q_values, q_postfix, args, - self.max_stmt_length, - self._get_db().encoding) + q_postfix = m.group(3) or "" + assert q_values[0] == "(" and q_values[-1] == ")" + return self._do_execute_many( + q_prefix, + q_values, + q_postfix, + args, + self.max_stmt_length, + self._get_db().encoding, + ) self.rowcount = sum(self.execute(query, arg) for arg in args) return self.rowcount - def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding): + def _do_execute_many( + self, prefix, values, postfix, args, max_stmt_length, encoding + ): conn = self._get_db() escape = self._escape_args - if isinstance(prefix, text_type): + if isinstance(prefix, str): prefix = prefix.encode(encoding) - if PY2 and isinstance(values, text_type): - values = values.encode(encoding) - if isinstance(postfix, text_type): + if isinstance(postfix, str): postfix = postfix.encode(encoding) sql = bytearray(prefix) args = iter(args) v = values % escape(next(args), conn) - if isinstance(v, text_type): - if PY2: - v = v.encode(encoding) - else: - v = v.encode(encoding, 'surrogateescape') + if isinstance(v, str): + v = v.encode(encoding, "surrogateescape") sql += v rows = 0 for arg in args: v = values % escape(arg, conn) - if isinstance(v, text_type): - if PY2: - v = v.encode(encoding) - else: - v = v.encode(encoding, 'surrogateescape') + if isinstance(v, str): + v = v.encode(encoding, "surrogateescape") if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length: rows += self.execute(sql + postfix) sql = bytearray(prefix) else: - sql += b',' + sql += b"," sql += v rows += self.execute(sql + postfix) self.rowcount = rows return rows def callproc(self, procname, args=()): - """Execute stored procedure procname with args + """Execute stored procedure procname with args. - procname -- string, name of procedure to execute on server + :param procname: Name of procedure to execute on server. + :type procname: str - args -- Sequence of parameters to use with procedure + :param args: Sequence of parameters to use with procedure. + :type args: tuple or list Returns the original args. @@ -265,20 +253,25 @@ def callproc(self, procname, args=()): """ conn = self._get_db() if args: - fmt = '@_{0}_%d=%s'.format(procname) - self._query('SET %s' % ','.join(fmt % (index, conn.escape(arg)) - for index, arg in enumerate(args))) + fmt = f"@_{procname}_%d=%s" + self._query( + "SET %s" + % ",".join( + fmt % (index, conn.escape(arg)) for index, arg in enumerate(args) + ) + ) self.nextset() - q = "CALL %s(%s)" % (procname, - ','.join(['@_%s_%d' % (procname, i) - for i in range_type(len(args))])) + q = "CALL {}({})".format( + procname, + ",".join(["@_%s_%d" % (procname, i) for i in range(len(args))]), + ) self._query(q) self._executed = q return args def fetchone(self): - """Fetch the next row""" + """Fetch the next row.""" self._check_executed() if self._rows is None or self.rownumber >= len(self._rows): return None @@ -287,32 +280,34 @@ def fetchone(self): return result def fetchmany(self, size=None): - """Fetch several rows""" + """Fetch several rows.""" self._check_executed() if self._rows is None: + # Django expects () for EOF. + # https://github.com/django/django/blob/0c1518ee429b01c145cf5b34eab01b0b92f8c246/django/db/backends/mysql/features.py#L8 return () end = self.rownumber + (size or self.arraysize) - result = self._rows[self.rownumber:end] + result = self._rows[self.rownumber : end] self.rownumber = min(end, len(self._rows)) return result def fetchall(self): - """Fetch all the rows""" + """Fetch all the rows.""" self._check_executed() if self._rows is None: - return () + return [] if self.rownumber: - result = self._rows[self.rownumber:] + result = self._rows[self.rownumber :] else: result = self._rows self.rownumber = len(self._rows) return result - def scroll(self, value, mode='relative'): + def scroll(self, value, mode="relative"): self._check_executed() - if mode == 'relative': + if mode == "relative": r = self.rownumber + value - elif mode == 'absolute': + elif mode == "absolute": r = value else: raise err.ProgrammingError("unknown scroll mode %s" % mode) @@ -323,7 +318,6 @@ def scroll(self, value, mode='relative'): def _query(self, q): conn = self._get_db() - self._last_executed = q self._clear_result() conn.query(q) self._do_get_result() @@ -334,6 +328,7 @@ def _clear_result(self): self._result = None self.rowcount = 0 + self.warning_count = 0 self.description = None self.lastrowid = None self._rows = None @@ -344,57 +339,57 @@ def _do_get_result(self): self._result = result = conn._result self.rowcount = result.affected_rows + self.warning_count = result.warning_count self.description = result.description self.lastrowid = result.insert_id self._rows = result.rows - self._warnings_handled = False - - if not self._defer_warnings: - self._show_warnings() - - def _show_warnings(self): - if self._warnings_handled: - return - self._warnings_handled = True - if self._result and (self._result.has_next or not self._result.warning_count): - return - ws = self._get_db().show_warnings() - if ws is None: - return - for w in ws: - msg = w[-1] - if PY2: - if isinstance(msg, unicode): - msg = msg.encode('utf-8', 'replace') - warnings.warn(err.Warning(*w[1:3]), stacklevel=4) def __iter__(self): - return iter(self.fetchone, None) - - Warning = err.Warning - Error = err.Error - InterfaceError = err.InterfaceError - DatabaseError = err.DatabaseError - DataError = err.DataError - OperationalError = err.OperationalError - IntegrityError = err.IntegrityError - InternalError = err.InternalError - ProgrammingError = err.ProgrammingError - NotSupportedError = err.NotSupportedError + return self + def __next__(self): + row = self.fetchone() + if row is None: + raise StopIteration + return row -class DictCursorMixin(object): + def __getattr__(self, name): + # DB-API 2.0 optional extension says these errors can be accessed + # via Connection object. But MySQLdb had defined them on Cursor object. + if name in ( + "Warning", + "Error", + "InterfaceError", + "DatabaseError", + "DataError", + "OperationalError", + "IntegrityError", + "InternalError", + "ProgrammingError", + "NotSupportedError", + ): + # Deprecated since v1.1 + warnings.warn( + "PyMySQL errors hould be accessed from `pymysql` package", + DeprecationWarning, + stacklevel=2, + ) + return getattr(err, name) + raise AttributeError(name) + + +class DictCursorMixin: # You can override this to use OrderedDict or other dict-like types. dict_type = dict def _do_get_result(self): - super(DictCursorMixin, self)._do_get_result() + super()._do_get_result() fields = [] if self.description: for f in self._result.fields: name = f.name if name in fields: - name = f.table_name + '.' + name + name = f.table_name + "." + name fields.append(name) self._fields = fields @@ -427,8 +422,6 @@ class SSCursor(Cursor): possible to scroll backwards, as only the current row is held in memory. """ - _defer_warnings = True - def _conv_row(self, row): return row @@ -450,7 +443,6 @@ def close(self): def _query(self, q): conn = self._get_db() - self._last_executed = q self._clear_result() conn.query(q, unbuffered=True) self._do_get_result() @@ -460,15 +452,15 @@ def nextset(self): return self._nextset(unbuffered=True) def read_next(self): - """Read next row""" + """Read next row.""" return self._conv_row(self._result._read_rowdata_packet_unbuffered()) def fetchone(self): - """Fetch next row""" + """Fetch next row.""" self._check_executed() row = self.read_next() if row is None: - self._show_warnings() + self.warning_count = self._result.warning_count return None self.rownumber += 1 return row @@ -489,43 +481,46 @@ def fetchall_unbuffered(self): """ return iter(self.fetchone, None) - def __iter__(self): - return self.fetchall_unbuffered() - def fetchmany(self, size=None): - """Fetch many""" + """Fetch many.""" self._check_executed() if size is None: size = self.arraysize rows = [] - for i in range_type(size): + for i in range(size): row = self.read_next() if row is None: - self._show_warnings() + self.warning_count = self._result.warning_count break rows.append(row) self.rownumber += 1 + if not rows: + # Django expects () for EOF. + # https://github.com/django/django/blob/0c1518ee429b01c145cf5b34eab01b0b92f8c246/django/db/backends/mysql/features.py#L8 + return () return rows - def scroll(self, value, mode='relative'): + def scroll(self, value, mode="relative"): self._check_executed() - if mode == 'relative': + if mode == "relative": if value < 0: raise err.NotSupportedError( - "Backwards scrolling not supported by this cursor") + "Backwards scrolling not supported by this cursor" + ) - for _ in range_type(value): + for _ in range(value): self.read_next() self.rownumber += value - elif mode == 'absolute': + elif mode == "absolute": if value < self.rownumber: raise err.NotSupportedError( - "Backwards scrolling not supported by this cursor") + "Backwards scrolling not supported by this cursor" + ) end = value - self.rownumber - for _ in range_type(end): + for _ in range(end): self.read_next() self.rownumber = value else: diff --git a/pymysql/err.py b/pymysql/err.py index fbc60558e..dac65d3be 100644 --- a/pymysql/err.py +++ b/pymysql/err.py @@ -74,36 +74,77 @@ def _map_error(exc, *errors): error_map[error] = exc -_map_error(ProgrammingError, ER.DB_CREATE_EXISTS, ER.SYNTAX_ERROR, - ER.PARSE_ERROR, ER.NO_SUCH_TABLE, ER.WRONG_DB_NAME, - ER.WRONG_TABLE_NAME, ER.FIELD_SPECIFIED_TWICE, - ER.INVALID_GROUP_FUNC_USE, ER.UNSUPPORTED_EXTENSION, - ER.TABLE_MUST_HAVE_COLUMNS, ER.CANT_DO_THIS_DURING_AN_TRANSACTION, - ER.WRONG_DB_NAME, ER.WRONG_COLUMN_NAME, - ) -_map_error(DataError, ER.WARN_DATA_TRUNCATED, ER.WARN_NULL_TO_NOTNULL, - ER.WARN_DATA_OUT_OF_RANGE, ER.NO_DEFAULT, ER.PRIMARY_CANT_HAVE_NULL, - ER.DATA_TOO_LONG, ER.DATETIME_FUNCTION_OVERFLOW) -_map_error(IntegrityError, ER.DUP_ENTRY, ER.NO_REFERENCED_ROW, - ER.NO_REFERENCED_ROW_2, ER.ROW_IS_REFERENCED, ER.ROW_IS_REFERENCED_2, - ER.CANNOT_ADD_FOREIGN, ER.BAD_NULL_ERROR) -_map_error(NotSupportedError, ER.WARNING_NOT_COMPLETE_ROLLBACK, - ER.NOT_SUPPORTED_YET, ER.FEATURE_DISABLED, ER.UNKNOWN_STORAGE_ENGINE) -_map_error(OperationalError, ER.DBACCESS_DENIED_ERROR, ER.ACCESS_DENIED_ERROR, - ER.CON_COUNT_ERROR, ER.TABLEACCESS_DENIED_ERROR, - ER.COLUMNACCESS_DENIED_ERROR, ER.CONSTRAINT_FAILED, ER.LOCK_DEADLOCK) +_map_error( + ProgrammingError, + ER.DB_CREATE_EXISTS, + ER.SYNTAX_ERROR, + ER.PARSE_ERROR, + ER.NO_SUCH_TABLE, + ER.WRONG_DB_NAME, + ER.WRONG_TABLE_NAME, + ER.FIELD_SPECIFIED_TWICE, + ER.INVALID_GROUP_FUNC_USE, + ER.UNSUPPORTED_EXTENSION, + ER.TABLE_MUST_HAVE_COLUMNS, + ER.CANT_DO_THIS_DURING_AN_TRANSACTION, + ER.WRONG_DB_NAME, + ER.WRONG_COLUMN_NAME, +) +_map_error( + DataError, + ER.WARN_DATA_TRUNCATED, + ER.WARN_NULL_TO_NOTNULL, + ER.WARN_DATA_OUT_OF_RANGE, + ER.NO_DEFAULT, + ER.PRIMARY_CANT_HAVE_NULL, + ER.DATA_TOO_LONG, + ER.DATETIME_FUNCTION_OVERFLOW, + ER.TRUNCATED_WRONG_VALUE_FOR_FIELD, + ER.ILLEGAL_VALUE_FOR_TYPE, +) +_map_error( + IntegrityError, + ER.DUP_ENTRY, + ER.NO_REFERENCED_ROW, + ER.NO_REFERENCED_ROW_2, + ER.ROW_IS_REFERENCED, + ER.ROW_IS_REFERENCED_2, + ER.CANNOT_ADD_FOREIGN, + ER.BAD_NULL_ERROR, +) +_map_error( + NotSupportedError, + ER.WARNING_NOT_COMPLETE_ROLLBACK, + ER.NOT_SUPPORTED_YET, + ER.FEATURE_DISABLED, + ER.UNKNOWN_STORAGE_ENGINE, +) +_map_error( + OperationalError, + ER.DBACCESS_DENIED_ERROR, + ER.ACCESS_DENIED_ERROR, + ER.CON_COUNT_ERROR, + ER.TABLEACCESS_DENIED_ERROR, + ER.COLUMNACCESS_DENIED_ERROR, + ER.CONSTRAINT_FAILED, + ER.LOCK_DEADLOCK, +) del _map_error, ER def raise_mysql_exception(data): - errno = struct.unpack('= 2 and value[0] == value[-1] == quote: return value[1:-1] return value + def optionxform(self, key): + return key.lower().replace("_", "-") + def get(self, section, option): value = configparser.RawConfigParser.get(self, section, option) return self.__remove_quotes(value) diff --git a/pymysql/protocol.py b/pymysql/protocol.py index 8ccf7c4d7..98fde6d0c 100644 --- a/pymysql/protocol.py +++ b/pymysql/protocol.py @@ -1,12 +1,9 @@ # Python implementation of low level MySQL client-server protocol # http://dev.mysql.com/doc/internals/en/client-server-protocol.html -from __future__ import print_function from .charset import MBLENGTH -from ._compat import PY2, range_type from .constants import FIELD_TYPE, SERVER_STATUS from . import err -from .util import byte2int import struct import sys @@ -23,11 +20,9 @@ def dump_packet(data): # pragma: no cover def printable(data): - if 32 <= byte2int(data) < 127: - if isinstance(data, int): - return chr(data) - return data - return '.' + if 32 <= data < 127: + return chr(data) + return "." try: print("packet length:", len(data)) @@ -37,21 +32,25 @@ def printable(data): print("-" * 66) except ValueError: pass - dump_data = [data[i:i+16] for i in range_type(0, min(len(data), 256), 16)] + dump_data = [data[i : i + 16] for i in range(0, min(len(data), 256), 16)] for d in dump_data: - print(' '.join("{:02X}".format(byte2int(x)) for x in d) + - ' ' * (16 - len(d)) + ' ' * 2 + - ''.join(printable(x) for x in d)) + print( + " ".join(f"{x:02X}" for x in d) + + " " * (16 - len(d)) + + " " * 2 + + "".join(printable(x) for x in d) + ) print("-" * 66) print() -class MysqlPacket(object): +class MysqlPacket: """Representation of a MySQL response packet. Provides an interface for reading/parsing the packet results. """ - __slots__ = ('_position', '_data') + + __slots__ = ("_position", "_data") def __init__(self, data, encoding): self._position = 0 @@ -62,11 +61,12 @@ def get_all_data(self): def read(self, size): """Read the first 'size' bytes in packet and advance cursor past them.""" - result = self._data[self._position:(self._position+size)] + result = self._data[self._position : (self._position + size)] if len(result) != size: - error = ('Result length not requested length:\n' - 'Expected=%s. Actual=%s. Position: %s. Data Length: %s' - % (size, len(result), self._position, len(self._data))) + error = ( + "Result length not requested length:\n" + f"Expected={size}. Actual={len(result)}. Position: {self._position}. Data Length: {len(self._data)}" + ) if DEBUG: print(error) self.dump() @@ -79,7 +79,7 @@ def read_all(self): (Subsequent read() will return errors.) """ - result = self._data[self._position:] + result = self._data[self._position :] self._position = None # ensure no subsequent read() return result @@ -87,8 +87,9 @@ def advance(self, length): """Advance the cursor in data buffer 'length' bytes.""" new_position = self._position + length if new_position < 0 or new_position > len(self._data): - raise Exception('Invalid advance amount (%s) for cursor. ' - 'Position=%s' % (length, new_position)) + raise Exception( + f"Invalid advance amount ({length}) for cursor. Position={new_position}" + ) self._position = new_position def rewind(self, position=0): @@ -106,44 +107,38 @@ def get_bytes(self, position, length=1): No error checking is done. If requesting outside end of buffer an empty string (or string shorter than 'length') may be returned! """ - return self._data[position:(position+length)] - - if PY2: - def read_uint8(self): - result = ord(self._data[self._position]) - self._position += 1 - return result - else: - def read_uint8(self): - result = self._data[self._position] - self._position += 1 - return result + return self._data[position : (position + length)] + + def read_uint8(self): + result = self._data[self._position] + self._position += 1 + return result def read_uint16(self): - result = struct.unpack_from('= 7 + return self._data[0] == 0 and len(self._data) >= 7 def is_eof_packet(self): # http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet # Caution: \xFE may be LengthEncodedInteger. # If \xFE is LengthEncodedInteger header, 8bytes followed. - return self._data[0:1] == b'\xfe' and len(self._data) < 9 + return self._data[0] == 0xFE and len(self._data) < 9 def is_auth_switch_request(self): # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest - return self._data[0:1] == b'\xfe' + return self._data[0] == 0xFE def is_extra_auth_data(self): # https://dev.mysql.com/doc/internals/en/successful-authentication.html - return self._data[0:1] == b'\x01' + return self._data[0] == 1 def is_resultset_packet(self): - field_count = ord(self._data[0:1]) + field_count = self._data[0] return 1 <= field_count <= 250 def is_load_local_packet(self): - return self._data[0:1] == b'\xfb' + return self._data[0] == 0xFB def is_error_packet(self): - return self._data[0:1] == b'\xff' + return self._data[0] == 0xFF def check_error(self): if self.is_error_packet(): - self.rewind() - self.advance(1) # field_count == error (we already know that) - errno = self.read_uint16() - if DEBUG: print("errno =", errno) - err.raise_mysql_exception(self._data) + self.raise_for_error() + + def raise_for_error(self): + self.rewind() + self.advance(1) # field_count == error (we already know that) + errno = self.read_uint16() + if DEBUG: + print("errno =", errno) + err.raise_mysql_exception(self._data) def dump(self): dump_packet(self._data) @@ -245,8 +244,13 @@ def _parse_field_descriptor(self, encoding): self.org_table = self.read_length_coded_string().decode(encoding) self.name = self.read_length_coded_string().decode(encoding) self.org_name = self.read_length_coded_string().decode(encoding) - self.charsetnr, self.length, self.type_code, self.flags, self.scale = ( - self.read_struct('= version_tuple + def get_mysql_vendor(self, conn): + server_version = conn.get_server_info() + + if "MariaDB" in server_version: + return "mariadb" + + return "mysql" + _connections = None @property @@ -55,10 +71,12 @@ def connect(self, **params): p = self.databases[0].copy() p.update(params) conn = pymysql.connect(**p) + @self.addCleanup def teardown(): if conn.open: conn.close() + return conn def _teardown_connections(self): @@ -80,7 +98,7 @@ def safe_create_table(self, connection, tablename, ddl, cleanup=True): with warnings.catch_warnings(): warnings.simplefilter("ignore") - cursor.execute("drop table if exists `%s`" % (tablename,)) + cursor.execute(f"drop table if exists `{tablename}`") cursor.execute(ddl) cursor.close() if cleanup: @@ -90,15 +108,5 @@ def drop_table(self, connection, tablename): cursor = connection.cursor() with warnings.catch_warnings(): warnings.simplefilter("ignore") - cursor.execute("drop table if exists `%s`" % (tablename,)) + cursor.execute(f"drop table if exists `{tablename}`") cursor.close() - - def safe_gc_collect(self): - """Ensure cycles are collected via gc. - - Runs additional times on non-CPython platforms. - - """ - gc.collect() - if not CPYTHON: - gc.collect() diff --git a/pymysql/tests/test_DictCursor.py b/pymysql/tests/test_DictCursor.py index 9a0d638b2..4e545792a 100644 --- a/pymysql/tests/test_DictCursor.py +++ b/pymysql/tests/test_DictCursor.py @@ -6,46 +6,50 @@ class TestDictCursor(base.PyMySQLTestCase): - bob = {'name': 'bob', 'age': 21, 'DOB': datetime.datetime(1990, 2, 6, 23, 4, 56)} - jim = {'name': 'jim', 'age': 56, 'DOB': datetime.datetime(1955, 5, 9, 13, 12, 45)} - fred = {'name': 'fred', 'age': 100, 'DOB': datetime.datetime(1911, 9, 12, 1, 1, 1)} + bob = {"name": "bob", "age": 21, "DOB": datetime.datetime(1990, 2, 6, 23, 4, 56)} + jim = {"name": "jim", "age": 56, "DOB": datetime.datetime(1955, 5, 9, 13, 12, 45)} + fred = {"name": "fred", "age": 100, "DOB": datetime.datetime(1911, 9, 12, 1, 1, 1)} cursor_type = pymysql.cursors.DictCursor def setUp(self): - super(TestDictCursor, self).setUp() - self.conn = conn = self.connections[0] + super().setUp() + self.conn = conn = self.connect() c = conn.cursor(self.cursor_type) - # create a table ane some data to query + # create a table and some data to query with warnings.catch_warnings(): warnings.filterwarnings("ignore") c.execute("drop table if exists dictcursor") # include in filterwarnings since for unbuffered dict cursor warning for lack of table # will only be propagated at start of next execute() call - c.execute("""CREATE TABLE dictcursor (name char(20), age int , DOB datetime)""") - data = [("bob", 21, "1990-02-06 23:04:56"), - ("jim", 56, "1955-05-09 13:12:45"), - ("fred", 100, "1911-09-12 01:01:01")] + c.execute( + """CREATE TABLE dictcursor (name char(20), age int , DOB datetime)""" + ) + data = [ + ("bob", 21, "1990-02-06 23:04:56"), + ("jim", 56, "1955-05-09 13:12:45"), + ("fred", 100, "1911-09-12 01:01:01"), + ] c.executemany("insert into dictcursor values (%s,%s,%s)", data) def tearDown(self): c = self.conn.cursor() c.execute("drop table dictcursor") - super(TestDictCursor, self).tearDown() + super().tearDown() def _ensure_cursor_expired(self, cursor): pass def test_DictCursor(self): bob, jim, fred = self.bob.copy(), self.jim.copy(), self.fred.copy() - #all assert test compare to the structure as would come out from MySQLdb + # all assert test compare to the structure as would come out from MySQLdb conn = self.conn c = conn.cursor(self.cursor_type) # try an update which should return no rows c.execute("update dictcursor set age=20 where name='bob'") - bob['age'] = 20 + bob["age"] = 20 # pull back the single row dict for bob and check c.execute("SELECT * from dictcursor where name='bob'") r = c.fetchone() @@ -55,19 +59,23 @@ def test_DictCursor(self): # same again, but via fetchall => tuple) c.execute("SELECT * from dictcursor where name='bob'") r = c.fetchall() - self.assertEqual([bob], r, "fetch a 1 row result via fetchall failed via DictCursor") + self.assertEqual( + [bob], r, "fetch a 1 row result via fetchall failed via DictCursor" + ) # same test again but iterate over the c.execute("SELECT * from dictcursor where name='bob'") for r in c: - self.assertEqual(bob, r, "fetch a 1 row result via iteration failed via DictCursor") + self.assertEqual( + bob, r, "fetch a 1 row result via iteration failed via DictCursor" + ) # get all 3 row via fetchall c.execute("SELECT * from dictcursor") r = c.fetchall() - self.assertEqual([bob,jim,fred], r, "fetchall failed via DictCursor") - #same test again but do a list comprehension + self.assertEqual([bob, jim, fred], r, "fetchall failed via DictCursor") + # same test again but do a list comprehension c.execute("SELECT * from dictcursor") r = list(c) - self.assertEqual([bob,jim,fred], r, "DictCursor should be iterable") + self.assertEqual([bob, jim, fred], r, "DictCursor should be iterable") # get all 2 row via fetchmany c.execute("SELECT * from dictcursor") r = c.fetchmany(2) @@ -75,12 +83,13 @@ def test_DictCursor(self): self._ensure_cursor_expired(c) def test_custom_dict(self): - class MyDict(dict): pass + class MyDict(dict): + pass class MyDictCursor(self.cursor_type): dict_type = MyDict - keys = ['name', 'age', 'DOB'] + keys = ["name", "age", "DOB"] bob = MyDict([(k, self.bob[k]) for k in keys]) jim = MyDict([(k, self.jim[k]) for k in keys]) fred = MyDict([(k, self.fred[k]) for k in keys]) @@ -93,18 +102,15 @@ class MyDictCursor(self.cursor_type): cur.execute("SELECT * FROM dictcursor") r = cur.fetchall() - self.assertEqual([bob, jim, fred], r, - "fetchall failed via MyDictCursor") + self.assertEqual([bob, jim, fred], r, "fetchall failed via MyDictCursor") cur.execute("SELECT * FROM dictcursor") r = list(cur) - self.assertEqual([bob, jim, fred], r, - "list failed via MyDictCursor") + self.assertEqual([bob, jim, fred], r, "list failed via MyDictCursor") cur.execute("SELECT * FROM dictcursor") r = cur.fetchmany(2) - self.assertEqual([bob, jim], r, - "list failed via MyDictCursor") + self.assertEqual([bob, jim], r, "list failed via MyDictCursor") self._ensure_cursor_expired(cur) @@ -114,6 +120,8 @@ class TestSSDictCursor(TestDictCursor): def _ensure_cursor_expired(self, cursor): list(cursor.fetchall_unbuffered()) + if __name__ == "__main__": import unittest + unittest.main() diff --git a/pymysql/tests/test_SSCursor.py b/pymysql/tests/test_SSCursor.py index 3bbfcfa41..d5e6e2bce 100644 --- a/pymysql/tests/test_SSCursor.py +++ b/pymysql/tests/test_SSCursor.py @@ -1,15 +1,9 @@ -import sys - -try: - from pymysql.tests import base - import pymysql.cursors - from pymysql.constants import CLIENT -except Exception: - # For local testing from top-level directory, without installing - sys.path.append('../pymysql') - from pymysql.tests import base - import pymysql.cursors - from pymysql.constants import CLIENT +import pytest + +from pymysql.tests import base +import pymysql.cursors +from pymysql.constants import CLIENT, ER + class TestSSCursor(base.PyMySQLTestCase): def test_SSCursor(self): @@ -17,35 +11,35 @@ def test_SSCursor(self): conn = self.connect(client_flag=CLIENT.MULTI_STATEMENTS) data = [ - ('America', '', 'America/Jamaica'), - ('America', '', 'America/Los_Angeles'), - ('America', '', 'America/Lima'), - ('America', '', 'America/New_York'), - ('America', '', 'America/Menominee'), - ('America', '', 'America/Havana'), - ('America', '', 'America/El_Salvador'), - ('America', '', 'America/Costa_Rica'), - ('America', '', 'America/Denver'), - ('America', '', 'America/Detroit'),] + ("America", "", "America/Jamaica"), + ("America", "", "America/Los_Angeles"), + ("America", "", "America/Lima"), + ("America", "", "America/New_York"), + ("America", "", "America/Menominee"), + ("America", "", "America/Havana"), + ("America", "", "America/El_Salvador"), + ("America", "", "America/Costa_Rica"), + ("America", "", "America/Denver"), + ("America", "", "America/Detroit"), + ] cursor = conn.cursor(pymysql.cursors.SSCursor) # Create table - cursor.execute('CREATE TABLE tz_data (' - 'region VARCHAR(64),' - 'zone VARCHAR(64),' - 'name VARCHAR(64))') + cursor.execute( + "CREATE TABLE tz_data (region VARCHAR(64), zone VARCHAR(64), name VARCHAR(64))" + ) conn.begin() # Test INSERT for i in data: - cursor.execute('INSERT INTO tz_data VALUES (%s, %s, %s)', i) - self.assertEqual(conn.affected_rows(), 1, 'affected_rows does not match') + cursor.execute("INSERT INTO tz_data VALUES (%s, %s, %s)", i) + self.assertEqual(conn.affected_rows(), 1, "affected_rows does not match") conn.commit() # Test fetchone() iter = 0 - cursor.execute('SELECT * FROM tz_data') + cursor.execute("SELECT * FROM tz_data") while True: row = cursor.fetchone() if row is None: @@ -53,26 +47,35 @@ def test_SSCursor(self): iter += 1 # Test cursor.rowcount - self.assertEqual(cursor.rowcount, affected_rows, - 'cursor.rowcount != %s' % (str(affected_rows))) + self.assertEqual( + cursor.rowcount, + affected_rows, + "cursor.rowcount != %s" % (str(affected_rows)), + ) # Test cursor.rownumber - self.assertEqual(cursor.rownumber, iter, - 'cursor.rowcount != %s' % (str(iter))) + self.assertEqual( + cursor.rownumber, iter, "cursor.rowcount != %s" % (str(iter)) + ) # Test row came out the same as it went in - self.assertEqual((row in data), True, - 'Row not found in source data') + self.assertEqual((row in data), True, "Row not found in source data") # Test fetchall - cursor.execute('SELECT * FROM tz_data') - self.assertEqual(len(cursor.fetchall()), len(data), - 'fetchall failed. Number of rows does not match') + cursor.execute("SELECT * FROM tz_data") + self.assertEqual( + len(cursor.fetchall()), + len(data), + "fetchall failed. Number of rows does not match", + ) # Test fetchmany - cursor.execute('SELECT * FROM tz_data') - self.assertEqual(len(cursor.fetchmany(2)), 2, - 'fetchmany failed. Number of rows does not match') + cursor.execute("SELECT * FROM tz_data") + self.assertEqual( + len(cursor.fetchmany(2)), + 2, + "fetchmany failed. Number of rows does not match", + ) # So MySQLdb won't throw "Commands out of sync" while True: @@ -81,30 +84,153 @@ def test_SSCursor(self): break # Test update, affected_rows() - cursor.execute('UPDATE tz_data SET zone = %s', ['Foo']) + cursor.execute("UPDATE tz_data SET zone = %s", ["Foo"]) conn.commit() - self.assertEqual(cursor.rowcount, len(data), - 'Update failed. affected_rows != %s' % (str(len(data)))) + self.assertEqual( + cursor.rowcount, + len(data), + "Update failed. affected_rows != %s" % (str(len(data))), + ) # Test executemany - cursor.executemany('INSERT INTO tz_data VALUES (%s, %s, %s)', data) - self.assertEqual(cursor.rowcount, len(data), - 'executemany failed. cursor.rowcount != %s' % (str(len(data)))) + cursor.executemany("INSERT INTO tz_data VALUES (%s, %s, %s)", data) + self.assertEqual( + cursor.rowcount, + len(data), + "executemany failed. cursor.rowcount != %s" % (str(len(data))), + ) # Test multiple datasets - cursor.execute('SELECT 1; SELECT 2; SELECT 3') - self.assertListEqual(list(cursor), [(1, )]) + cursor.execute("SELECT 1; SELECT 2; SELECT 3") + self.assertListEqual(list(cursor), [(1,)]) self.assertTrue(cursor.nextset()) - self.assertListEqual(list(cursor), [(2, )]) + self.assertListEqual(list(cursor), [(2,)]) self.assertTrue(cursor.nextset()) - self.assertListEqual(list(cursor), [(3, )]) + self.assertListEqual(list(cursor), [(3,)]) self.assertFalse(cursor.nextset()) - cursor.execute('DROP TABLE IF EXISTS tz_data') + cursor.execute("DROP TABLE IF EXISTS tz_data") cursor.close() + def test_execution_time_limit(self): + # this method is similarly implemented in test_cursor + + conn = self.connect() + + # table creation and filling is SSCursor only as it's not provided by self.setUp() + self.safe_create_table( + conn, + "test", + "create table test (data varchar(10))", + ) + with conn.cursor() as cur: + cur.execute( + "insert into test (data) values " + "('row1'), ('row2'), ('row3'), ('row4'), ('row5')" + ) + conn.commit() + + db_type = self.get_mysql_vendor(conn) + + with conn.cursor(pymysql.cursors.SSCursor) as cur: + # MySQL MAX_EXECUTION_TIME takes ms + # MariaDB max_statement_time takes seconds as int/float, introduced in 10.1 + + # this will sleep 0.01 seconds per row + if db_type == "mysql": + sql = ( + "SELECT /*+ MAX_EXECUTION_TIME(2000) */ data, sleep(0.01) FROM test" + ) + else: + sql = "SET STATEMENT max_statement_time=2 FOR SELECT data, sleep(0.01) FROM test" + + cur.execute(sql) + # unlike Cursor, SSCursor returns a list of tuples here + self.assertEqual( + cur.fetchall(), + [ + ("row1", 0), + ("row2", 0), + ("row3", 0), + ("row4", 0), + ("row5", 0), + ], + ) + + if db_type == "mysql": + sql = ( + "SELECT /*+ MAX_EXECUTION_TIME(2000) */ data, sleep(0.01) FROM test" + ) + else: + sql = "SET STATEMENT max_statement_time=2 FOR SELECT data, sleep(0.01) FROM test" + cur.execute(sql) + self.assertEqual(cur.fetchone(), ("row1", 0)) + + # this discards the previous unfinished query and raises an + # incomplete unbuffered query warning + with pytest.warns(UserWarning): + cur.execute("SELECT 1") + self.assertEqual(cur.fetchone(), (1,)) + + # SSCursor will not read the EOF packet until we try to read + # another row. Skipping this will raise an incomplete unbuffered + # query warning in the next cur.execute(). + self.assertEqual(cur.fetchone(), None) + + if db_type == "mysql": + sql = "SELECT /*+ MAX_EXECUTION_TIME(1) */ data, sleep(1) FROM test" + else: + sql = "SET STATEMENT max_statement_time=0.001 FOR SELECT data, sleep(1) FROM test" + with pytest.raises(pymysql.err.OperationalError) as cm: + # in an unbuffered cursor the OperationalError may not show up + # until fetching the entire result + cur.execute(sql) + cur.fetchall() + + if db_type == "mysql": + # this constant was only introduced in MySQL 5.7, not sure + # what was returned before, may have been ER_QUERY_INTERRUPTED + self.assertEqual(cm.value.args[0], ER.QUERY_TIMEOUT) + else: + self.assertEqual(cm.value.args[0], ER.STATEMENT_TIMEOUT) + + # connection should still be fine at this point + cur.execute("SELECT 1") + self.assertEqual(cur.fetchone(), (1,)) + + def test_warnings(self): + con = self.connect() + cur = con.cursor(pymysql.cursors.SSCursor) + cur.execute("DROP TABLE IF EXISTS `no_exists_table`") + self.assertEqual(cur.warning_count, 1) + + cur.execute("SHOW WARNINGS") + w = cur.fetchone() + self.assertEqual(w[1], ER.BAD_TABLE_ERROR) + self.assertIn( + "no_exists_table", + w[2], + ) + + # ensure unbuffered result is finished + self.assertIsNone(cur.fetchone()) + + cur.execute("SELECT 1") + self.assertEqual(cur.fetchone(), (1,)) + self.assertIsNone(cur.fetchone()) + + self.assertEqual(cur.warning_count, 0) + + cur.execute("SELECT CAST('abc' AS SIGNED)") + # this ensures fully retrieving the unbuffered result + rows = cur.fetchmany(2) + self.assertEqual(len(rows), 1) + self.assertEqual(cur.warning_count, 1) + + __all__ = ["TestSSCursor"] if __name__ == "__main__": import unittest + unittest.main() diff --git a/pymysql/tests/test_basic.py b/pymysql/tests/test_basic.py index a53373224..0fe13b59d 100644 --- a/pymysql/tests/test_basic.py +++ b/pymysql/tests/test_basic.py @@ -1,15 +1,11 @@ -# coding: utf-8 import datetime import json import time -import warnings -from unittest2 import SkipTest +import pytest -from pymysql import util import pymysql.cursors from pymysql.tests import base -from pymysql.err import ProgrammingError __all__ = ["TestConversion", "TestCursor", "TestBulkInserts"] @@ -17,26 +13,66 @@ class TestConversion(base.PyMySQLTestCase): def test_datatypes(self): - """ test every data type """ - conn = self.connections[0] + """test every data type""" + conn = self.connect() c = conn.cursor() - c.execute("create table test_datatypes (b bit, i int, l bigint, f real, s varchar(32), u varchar(32), bb blob, d date, dt datetime, ts timestamp, td time, t time, st datetime)") + c.execute( + """ +create table test_datatypes ( + b bit, + i int, + l bigint, + f real, + s varchar(32), + u varchar(32), + bb blob, + d date, + dt datetime, + ts timestamp, + td time, + t time, + st datetime) +""" + ) try: # insert values - v = (True, -3, 123456789012, 5.7, "hello'\" world", u"Espa\xc3\xb1ol", "binary\x00data".encode(conn.encoding), datetime.date(1988,2,2), datetime.datetime(2014, 5, 15, 7, 45, 57), datetime.timedelta(5,6), datetime.time(16,32), time.localtime()) - c.execute("insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st) values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", v) + v = ( + True, + -3, + 123456789012, + 5.7, + "hello'\" world", + "Espa\xc3\xb1ol", + "binary\x00data".encode(conn.encoding), + datetime.date(1988, 2, 2), + datetime.datetime(2014, 5, 15, 7, 45, 57), + datetime.timedelta(5, 6), + datetime.time(16, 32), + time.localtime(), + ) + c.execute( + "insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st) values" + " (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", + v, + ) c.execute("select b,i,l,f,s,u,bb,d,dt,td,t,st from test_datatypes") r = c.fetchone() - self.assertEqual(util.int2byte(1), r[0]) + self.assertEqual(b"\x01", r[0]) self.assertEqual(v[1:10], r[1:10]) - self.assertEqual(datetime.timedelta(0, 60 * (v[10].hour * 60 + v[10].minute)), r[10]) + self.assertEqual( + datetime.timedelta(0, 60 * (v[10].hour * 60 + v[10].minute)), r[10] + ) self.assertEqual(datetime.datetime(*v[-1][:6]), r[-1]) c.execute("delete from test_datatypes") # check nulls - c.execute("insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st) values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", [None] * 12) + c.execute( + "insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st)" + " values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", + [None] * 12, + ) c.execute("select b,i,l,f,s,u,bb,d,dt,td,t,st from test_datatypes") r = c.fetchone() self.assertEqual(tuple([None] * 12), r) @@ -45,30 +81,37 @@ def test_datatypes(self): # check sequences type for seq_type in (tuple, list, set, frozenset): - c.execute("insert into test_datatypes (i, l) values (2,4), (6,8), (10,12)") - seq = seq_type([2,6]) - c.execute("select l from test_datatypes where i in %s order by i", (seq,)) + c.execute( + "insert into test_datatypes (i, l) values (2,4), (6,8), (10,12)" + ) + seq = seq_type([2, 6]) + c.execute( + "select l from test_datatypes where i in %s order by i", (seq,) + ) r = c.fetchall() - self.assertEqual(((4,),(8,)), r) + self.assertEqual(((4,), (8,)), r) c.execute("delete from test_datatypes") finally: c.execute("drop table test_datatypes") def test_dict(self): - """ test dict escaping """ - conn = self.connections[0] + """test dict escaping""" + conn = self.connect() c = conn.cursor() c.execute("create table test_dict (a integer, b integer, c integer)") try: - c.execute("insert into test_dict (a,b,c) values (%(a)s, %(b)s, %(c)s)", {"a":1,"b":2,"c":3}) + c.execute( + "insert into test_dict (a,b,c) values (%(a)s, %(b)s, %(c)s)", + {"a": 1, "b": 2, "c": 3}, + ) c.execute("select a,b,c from test_dict") - self.assertEqual((1,2,3), c.fetchone()) + self.assertEqual((1, 2, 3), c.fetchone()) finally: c.execute("drop table test_dict") def test_string(self): - conn = self.connections[0] + conn = self.connect() c = conn.cursor() c.execute("create table test_dict (a text)") test_value = "I am a test string" @@ -80,7 +123,7 @@ def test_string(self): c.execute("drop table test_dict") def test_integer(self): - conn = self.connections[0] + conn = self.connect() c = conn.cursor() c.execute("create table test_dict (a integer)") test_value = 12345 @@ -94,9 +137,10 @@ def test_integer(self): def test_binary(self): """test binary data""" data = bytes(bytearray(range(255))) - conn = self.connections[0] + conn = self.connect() self.safe_create_table( - conn, "test_binary", "create table test_binary (b binary(255))") + conn, "test_binary", "create table test_binary (b binary(255))" + ) with conn.cursor() as c: c.execute("insert into test_binary (b) values (_binary %s)", (data,)) @@ -106,9 +150,8 @@ def test_binary(self): def test_blob(self): """test blob data""" data = bytes(bytearray(range(256)) * 4) - conn = self.connections[0] - self.safe_create_table( - conn, "test_blob", "create table test_blob (b blob)") + conn = self.connect() + self.safe_create_table(conn, "test_blob", "create table test_blob (b blob)") with conn.cursor() as c: c.execute("insert into test_blob (b) values (_binary %s)", (data,)) @@ -116,42 +159,44 @@ def test_blob(self): self.assertEqual(data, c.fetchone()[0]) def test_untyped(self): - """ test conversion of null, empty string """ - conn = self.connections[0] + """test conversion of null, empty string""" + conn = self.connect() c = conn.cursor() c.execute("select null,''") - self.assertEqual((None,u''), c.fetchone()) + self.assertEqual((None, ""), c.fetchone()) c.execute("select '',null") - self.assertEqual((u'',None), c.fetchone()) + self.assertEqual(("", None), c.fetchone()) def test_timedelta(self): - """ test timedelta conversion """ - conn = self.connections[0] + """test timedelta conversion""" + conn = self.connect() c = conn.cursor() - c.execute("select time('12:30'), time('23:12:59'), time('23:12:59.05100'), time('-12:30'), time('-23:12:59'), time('-23:12:59.05100'), time('-00:30')") - self.assertEqual((datetime.timedelta(0, 45000), - datetime.timedelta(0, 83579), - datetime.timedelta(0, 83579, 51000), - -datetime.timedelta(0, 45000), - -datetime.timedelta(0, 83579), - -datetime.timedelta(0, 83579, 51000), - -datetime.timedelta(0, 1800)), - c.fetchone()) + c.execute( + "select time('12:30'), time('23:12:59'), time('23:12:59.05100')," + + " time('-12:30'), time('-23:12:59'), time('-23:12:59.05100'), time('-00:30')" + ) + self.assertEqual( + ( + datetime.timedelta(0, 45000), + datetime.timedelta(0, 83579), + datetime.timedelta(0, 83579, 51000), + -datetime.timedelta(0, 45000), + -datetime.timedelta(0, 83579), + -datetime.timedelta(0, 83579, 51000), + -datetime.timedelta(0, 1800), + ), + c.fetchone(), + ) def test_datetime_microseconds(self): - """ test datetime conversion w microseconds""" + """test datetime conversion w microseconds""" - conn = self.connections[0] - if not self.mysql_server_is(conn, (5, 6, 4)): - raise SkipTest("target backend does not support microseconds") + conn = self.connect() c = conn.cursor() dt = datetime.datetime(2013, 11, 12, 9, 9, 9, 123450) c.execute("create table test_datetime (id int, ts datetime(6))") try: - c.execute( - "insert into test_datetime values (%s, %s)", - (1, dt) - ) + c.execute("insert into test_datetime values (%s, %s)", (1, dt)) c.execute("select ts from test_datetime") self.assertEqual((dt,), c.fetchone()) finally: @@ -164,7 +209,7 @@ class TestCursor(base.PyMySQLTestCase): # compatible with the DB-API 2.0 spec and has not broken # any unit tests for anything we've tried. - #def test_description(self): + # def test_description(self): # """ test description attribute """ # # result is from MySQLdb module # r = (('Host', 254, 11, 60, 60, 0, 0), @@ -206,15 +251,15 @@ class TestCursor(base.PyMySQLTestCase): # ('max_updates', 3, 1, 11, 11, 0, 0), # ('max_connections', 3, 1, 11, 11, 0, 0), # ('max_user_connections', 3, 1, 11, 11, 0, 0)) - # conn = self.connections[0] + # conn = self.connect() # c = conn.cursor() # c.execute("select * from mysql.user") # # self.assertEqual(r, c.description) def test_fetch_no_result(self): - """ test a fetchone() with no rows """ - conn = self.connections[0] + """test a fetchone() with no rows""" + conn = self.connect() c = conn.cursor() c.execute("create table test_nr (b varchar(32))") try: @@ -225,26 +270,26 @@ def test_fetch_no_result(self): c.execute("drop table test_nr") def test_aggregates(self): - """ test aggregate functions """ - conn = self.connections[0] + """test aggregate functions""" + conn = self.connect() c = conn.cursor() try: - c.execute('create table test_aggregates (i integer)') + c.execute("create table test_aggregates (i integer)") for i in range(0, 10): - c.execute('insert into test_aggregates (i) values (%s)', (i,)) - c.execute('select sum(i) from test_aggregates') - r, = c.fetchone() - self.assertEqual(sum(range(0,10)), r) + c.execute("insert into test_aggregates (i) values (%s)", (i,)) + c.execute("select sum(i) from test_aggregates") + (r,) = c.fetchone() + self.assertEqual(sum(range(0, 10)), r) finally: - c.execute('drop table test_aggregates') + c.execute("drop table test_aggregates") def test_single_tuple(self): - """ test a single tuple """ - conn = self.connections[0] + """test a single tuple""" + conn = self.connect() c = conn.cursor() self.safe_create_table( - conn, 'mystuff', - "create table mystuff (id integer primary key)") + conn, "mystuff", "create table mystuff (id integer primary key)" + ) c.execute("insert into mystuff (id) values (1)") c.execute("insert into mystuff (id) values (2)") c.execute("select id from mystuff where id in %s", ((1,),)) @@ -255,82 +300,102 @@ def test_json(self): args = self.databases[0].copy() args["charset"] = "utf8mb4" conn = pymysql.connect(**args) + # MariaDB only has limited JSON support, stores data as longtext + # https://mariadb.com/kb/en/json-data-type/ if not self.mysql_server_is(conn, (5, 7, 0)): - raise SkipTest("JSON type is not supported on MySQL <= 5.6") + pytest.skip("JSON type is only supported on MySQL >= 5.7") - self.safe_create_table(conn, "test_json", """\ + self.safe_create_table( + conn, + "test_json", + """\ create table test_json ( id int not null, json JSON not null, primary key (id) -);""") +);""", + ) cur = conn.cursor() - json_str = u'{"hello": "こんãĢãĄã¯"}' + json_str = '{"hello": "こんãĢãĄã¯"}' cur.execute("INSERT INTO test_json (id, `json`) values (42, %s)", (json_str,)) cur.execute("SELECT `json` from `test_json` WHERE `id`=42") res = cur.fetchone()[0] self.assertEqual(json.loads(res), json.loads(json_str)) - cur.execute("SELECT CAST(%s AS JSON) AS x", (json_str,)) - res = cur.fetchone()[0] - self.assertEqual(json.loads(res), json.loads(json_str)) + if self.get_mysql_vendor(conn) == "mysql": + cur.execute("SELECT CAST(%s AS JSON) AS x", (json_str,)) + res = cur.fetchone()[0] + self.assertEqual(json.loads(res), json.loads(json_str)) class TestBulkInserts(base.PyMySQLTestCase): - cursor_type = pymysql.cursors.DictCursor def setUp(self): - super(TestBulkInserts, self).setUp() - self.conn = conn = self.connections[0] - c = conn.cursor(self.cursor_type) + super().setUp() + self.conn = conn = self.connect() - # create a table ane some data to query - self.safe_create_table(conn, 'bulkinsert', """\ + # create a table and some data to query + self.safe_create_table( + conn, + "bulkinsert", + """\ CREATE TABLE bulkinsert ( -id int(11), +id int, name char(20), age int, height int, PRIMARY KEY (id) ) -""") +""", + ) def _verify_records(self, data): - conn = self.connections[0] + conn = self.connect() cursor = conn.cursor() cursor.execute("SELECT id, name, age, height from bulkinsert") result = cursor.fetchall() self.assertEqual(sorted(data), sorted(result)) def test_bulk_insert(self): - conn = self.connections[0] + conn = self.connect() cursor = conn.cursor() data = [(0, "bob", 21, 123), (1, "jim", 56, 45), (2, "fred", 100, 180)] - cursor.executemany("insert into bulkinsert (id, name, age, height) " - "values (%s,%s,%s,%s)", data) + cursor.executemany( + "insert into bulkinsert (id, name, age, height) values (%s,%s,%s,%s)", + data, + ) self.assertEqual( - cursor._last_executed, bytearray( - b"insert into bulkinsert (id, name, age, height) values " - b"(0,'bob',21,123),(1,'jim',56,45),(2,'fred',100,180)")) - cursor.execute('commit') + cursor._executed, + bytearray( + b"insert into bulkinsert (id, name, age, height) values " + b"(0,'bob',21,123),(1,'jim',56,45),(2,'fred',100,180)" + ), + ) + cursor.execute("commit") self._verify_records(data) def test_bulk_insert_multiline_statement(self): - conn = self.connections[0] + conn = self.connect() cursor = conn.cursor() data = [(0, "bob", 21, 123), (1, "jim", 56, 45), (2, "fred", 100, 180)] - cursor.executemany("""insert + cursor.executemany( + """insert into bulkinsert (id, name, age, height) values (%s, %s , %s, %s ) - """, data) - self.assertEqual(cursor._last_executed.strip(), bytearray(b"""insert + """, + data, + ) + self.assertEqual( + cursor._executed.strip(), + bytearray( + b"""insert into bulkinsert (id, name, age, height) values (0, @@ -339,33 +404,43 @@ def test_bulk_insert_multiline_statement(self): 'jim' , 56, 45 ),(2, 'fred' , 100, -180 )""")) - cursor.execute('commit') +180 )""" + ), + ) + cursor.execute("commit") self._verify_records(data) def test_bulk_insert_single_record(self): - conn = self.connections[0] + conn = self.connect() cursor = conn.cursor() data = [(0, "bob", 21, 123)] - cursor.executemany("insert into bulkinsert (id, name, age, height) " - "values (%s,%s,%s,%s)", data) - cursor.execute('commit') + cursor.executemany( + "insert into bulkinsert (id, name, age, height) values (%s,%s,%s,%s)", + data, + ) + cursor.execute("commit") self._verify_records(data) def test_issue_288(self): - """executemany should work with "insert ... on update" """ - conn = self.connections[0] + """executemany should work with "insert ... on update""" + conn = self.connect() cursor = conn.cursor() data = [(0, "bob", 21, 123), (1, "jim", 56, 45), (2, "fred", 100, 180)] - cursor.executemany("""insert + cursor.executemany( + """insert into bulkinsert (id, name, age, height) values (%s, %s , %s, %s ) on duplicate key update age = values(age) - """, data) - self.assertEqual(cursor._last_executed.strip(), bytearray(b"""insert + """, + data, + ) + self.assertEqual( + cursor._executed.strip(), + bytearray( + b"""insert into bulkinsert (id, name, age, height) values (0, @@ -375,17 +450,8 @@ def test_issue_288(self): 45 ),(2, 'fred' , 100, 180 ) on duplicate key update -age = values(age)""")) - cursor.execute('commit') +age = values(age)""" + ), + ) + cursor.execute("commit") self._verify_records(data) - - def test_warnings(self): - con = self.connections[0] - cur = con.cursor() - with warnings.catch_warnings(record=True) as ws: - warnings.simplefilter("always") - cur.execute("drop table if exists no_exists_table") - self.assertEqual(len(ws), 1) - self.assertEqual(ws[0].category, pymysql.Warning) - if u"no_exists_table" not in str(ws[0].message): - self.fail("'no_exists_table' not in %s" % (str(ws[0].message),)) diff --git a/pymysql/tests/test_charset.py b/pymysql/tests/test_charset.py new file mode 100644 index 000000000..94e6e1559 --- /dev/null +++ b/pymysql/tests/test_charset.py @@ -0,0 +1,25 @@ +import pymysql.charset + + +def test_utf8(): + utf8mb3 = pymysql.charset.charset_by_name("utf8mb3") + assert utf8mb3.name == "utf8mb3" + assert utf8mb3.collation == "utf8mb3_general_ci" + assert ( + repr(utf8mb3) + == "Charset(id=33, name='utf8mb3', collation='utf8mb3_general_ci')" + ) + + # MySQL 8.0 changed the default collation for utf8mb4. + # But we use old default for compatibility. + utf8mb4 = pymysql.charset.charset_by_name("utf8mb4") + assert utf8mb4.name == "utf8mb4" + assert utf8mb4.collation == "utf8mb4_general_ci" + assert ( + repr(utf8mb4) + == "Charset(id=45, name='utf8mb4', collation='utf8mb4_general_ci')" + ) + + # utf8 is alias of utf8mb4 since MySQL 8.0, and PyMySQL v1.1. + utf8 = pymysql.charset.charset_by_name("utf8") + assert utf8 == utf8mb4 diff --git a/pymysql/tests/test_connection.py b/pymysql/tests/test_connection.py index 5e95b1c8c..dcf3394c1 100644 --- a/pymysql/tests/test_connection.py +++ b/pymysql/tests/test_connection.py @@ -1,10 +1,11 @@ import datetime -import sys +import ssl +import pytest import time -import unittest2 +from unittest import mock + import pymysql from pymysql.tests import base -from pymysql._compat import text_type from pymysql.constants import CLIENT @@ -27,7 +28,7 @@ def __init__(self, c, user, db, auth=None, authdata=None, password=None): # already exists - TODO need to check the same plugin applies self._created = False try: - c.execute("GRANT SELECT ON %s.* TO %s" % (db, user)) + c.execute(f"GRANT SELECT ON {db}.* TO {user}") self._grant = True except pymysql.err.InternalError: self._grant = False @@ -37,13 +38,12 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): if self._grant: - self._c.execute("REVOKE SELECT ON %s.* FROM %s" % (self._db, self._user)) + self._c.execute(f"REVOKE SELECT ON {self._db}.* FROM {self._user}") if self._created: self._c.execute("DROP USER %s" % self._user) class TestAuthentication(base.PyMySQLTestCase): - socket_auth = False socket_found = False two_questions_found = False @@ -51,36 +51,40 @@ class TestAuthentication(base.PyMySQLTestCase): pam_found = False mysql_old_password_found = False sha256_password_found = False + ed25519_found = False import os - osuser = os.environ.get('USER') + + osuser = os.environ.get("USER") # socket auth requires the current user and for the connection to be a socket # rest do grants @localhost due to incomplete logic - TODO change to @% then db = base.PyMySQLTestCase.databases[0].copy() - socket_auth = db.get('unix_socket') is not None \ - and db.get('host') in ('localhost', '127.0.0.1') + socket_auth = db.get("unix_socket") is not None and db.get("host") in ( + "localhost", + "127.0.0.1", + ) cur = pymysql.connect(**db).cursor() - del db['user'] + del db["user"] cur.execute("SHOW PLUGINS") for r in cur: - if (r[1], r[2]) != (u'ACTIVE', u'AUTHENTICATION'): + if (r[1], r[2]) != ("ACTIVE", "AUTHENTICATION"): continue - if r[3] == u'auth_socket.so': + if r[3] == "auth_socket.so" or r[0] == "unix_socket": socket_plugin_name = r[0] socket_found = True - elif r[3] == u'dialog_examples.so': - if r[0] == 'two_questions': - two_questions_found = True - elif r[0] == 'three_attempts': - three_attempts_found = True - elif r[0] == u'pam': + elif r[3] == "dialog_examples.so": + if r[0] == "two_questions": + two_questions_found = True + elif r[0] == "three_attempts": + three_attempts_found = True + elif r[0] == "pam": pam_found = True - pam_plugin_name = r[3].split('.')[0] - if pam_plugin_name == 'auth_pam': - pam_plugin_name = 'pam' + pam_plugin_name = r[3].split(".")[0] + if pam_plugin_name == "auth_pam": + pam_plugin_name = "pam" # MySQL: authentication_pam # https://dev.mysql.com/doc/refman/5.5/en/pam-authentication-plugin.html @@ -88,306 +92,385 @@ class TestAuthentication(base.PyMySQLTestCase): # https://mariadb.com/kb/en/mariadb/pam-authentication-plugin/ # Names differ but functionality is close - elif r[0] == u'mysql_old_password': + elif r[0] == "mysql_old_password": mysql_old_password_found = True - elif r[0] == u'sha256_password': + elif r[0] == "sha256_password": sha256_password_found = True - #else: + elif r[0] == "ed25519": + ed25519_found = True + # else: # print("plugin: %r" % r[0]) def test_plugin(self): - if not self.mysql_server_is(self.connections[0], (5, 5, 0)): - raise unittest2.SkipTest("MySQL-5.5 required for plugins") - cur = self.connections[0].cursor() - cur.execute("select plugin from mysql.user where concat(user, '@', host)=current_user()") + conn = self.connect() + cur = conn.cursor() + cur.execute( + "select plugin from mysql.user where concat(user, '@', host)=current_user()" + ) for r in cur: - self.assertIn(self.connections[0]._auth_plugin_name, (r[0], 'mysql_native_password')) + self.assertIn(conn._auth_plugin_name, (r[0], "mysql_native_password")) - @unittest2.skipUnless(socket_auth, "connection to unix_socket required") - @unittest2.skipIf(socket_found, "socket plugin already installed") + @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required") + @pytest.mark.skipif(socket_found, reason="socket plugin already installed") def testSocketAuthInstallPlugin(self): # needs plugin. lets install it. - cur = self.connections[0].cursor() + cur = self.connect().cursor() try: cur.execute("install plugin auth_socket soname 'auth_socket.so'") TestAuthentication.socket_found = True - self.socket_plugin_name = 'auth_socket' + self.socket_plugin_name = "auth_socket" self.realtestSocketAuth() except pymysql.err.InternalError: try: cur.execute("install soname 'auth_socket'") TestAuthentication.socket_found = True - self.socket_plugin_name = 'unix_socket' + self.socket_plugin_name = "unix_socket" self.realtestSocketAuth() except pymysql.err.InternalError: TestAuthentication.socket_found = False - raise unittest2.SkipTest('we couldn\'t install the socket plugin') + pytest.skip("we couldn't install the socket plugin") finally: if TestAuthentication.socket_found: cur.execute("uninstall plugin %s" % self.socket_plugin_name) - @unittest2.skipUnless(socket_auth, "connection to unix_socket required") - @unittest2.skipUnless(socket_found, "no socket plugin") + @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required") + @pytest.mark.skipif(not socket_found, reason="no socket plugin") def testSocketAuth(self): self.realtestSocketAuth() def realtestSocketAuth(self): - with TempUser(self.connections[0].cursor(), TestAuthentication.osuser + '@localhost', - self.databases[0]['db'], self.socket_plugin_name) as u: - c = pymysql.connect(user=TestAuthentication.osuser, **self.db) + with TempUser( + self.connect().cursor(), + TestAuthentication.osuser + "@localhost", + self.databases[0]["database"], + self.socket_plugin_name, + ): + pymysql.connect(user=TestAuthentication.osuser, **self.db) - class Dialog(object): - fail=False + class Dialog: + fail = False def __init__(self, con): - self.fail=TestAuthentication.Dialog.fail + self.fail = TestAuthentication.Dialog.fail pass def prompt(self, echo, prompt): if self.fail: - self.fail=False - return b'bad guess at a password' + self.fail = False + return b"bad guess at a password" return self.m.get(prompt) - class DialogHandler(object): - + class DialogHandler: def __init__(self, con): - self.con=con + self.con = con def authenticate(self, pkt): while True: flag = pkt.read_uint8() - echo = (flag & 0x06) == 0x02 + # echo = (flag & 0x06) == 0x02 last = (flag & 0x01) == 0x01 prompt = pkt.read_all() - if prompt == b'Password, please:': - self.con.write_packet(b'stillnotverysecret\0') + if prompt == b"Password, please:": + self.con.write_packet(b"stillnotverysecret\0") else: - self.con.write_packet(b'no idea what to do with this prompt\0') + self.con.write_packet(b"no idea what to do with this prompt\0") pkt = self.con._read_packet() pkt.check_error() if pkt.is_ok_packet() or last: break return pkt - class DefectiveHandler(object): + class DefectiveHandler: def __init__(self, con): - self.con=con - + self.con = con - @unittest2.skipUnless(socket_auth, "connection to unix_socket required") - @unittest2.skipIf(two_questions_found, "two_questions plugin already installed") + @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required") + @pytest.mark.skipif( + two_questions_found, reason="two_questions plugin already installed" + ) def testDialogAuthTwoQuestionsInstallPlugin(self): # needs plugin. lets install it. - cur = self.connections[0].cursor() + cur = self.connect().cursor() try: cur.execute("install plugin two_questions soname 'dialog_examples.so'") TestAuthentication.two_questions_found = True self.realTestDialogAuthTwoQuestions() except pymysql.err.InternalError: - raise unittest2.SkipTest('we couldn\'t install the two_questions plugin') + pytest.skip("we couldn't install the two_questions plugin") finally: if TestAuthentication.two_questions_found: cur.execute("uninstall plugin two_questions") - @unittest2.skipUnless(socket_auth, "connection to unix_socket required") - @unittest2.skipUnless(two_questions_found, "no two questions auth plugin") + @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required") + @pytest.mark.skipif(not two_questions_found, reason="no two questions auth plugin") def testDialogAuthTwoQuestions(self): self.realTestDialogAuthTwoQuestions() def realTestDialogAuthTwoQuestions(self): - TestAuthentication.Dialog.fail=False - TestAuthentication.Dialog.m = {b'Password, please:': b'notverysecret', - b'Are you sure ?': b'yes, of course'} - with TempUser(self.connections[0].cursor(), 'pymysql_2q@localhost', - self.databases[0]['db'], 'two_questions', 'notverysecret') as u: + TestAuthentication.Dialog.fail = False + TestAuthentication.Dialog.m = { + b"Password, please:": b"notverysecret", + b"Are you sure ?": b"yes, of course", + } + with TempUser( + self.connect().cursor(), + "pymysql_2q@localhost", + self.databases[0]["database"], + "two_questions", + "notverysecret", + ): with self.assertRaises(pymysql.err.OperationalError): - pymysql.connect(user='pymysql_2q', **self.db) - pymysql.connect(user='pymysql_2q', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db) - - @unittest2.skipUnless(socket_auth, "connection to unix_socket required") - @unittest2.skipIf(three_attempts_found, "three_attempts plugin already installed") + pymysql.connect(user="pymysql_2q", **self.db) + pymysql.connect( + user="pymysql_2q", + auth_plugin_map={b"dialog": TestAuthentication.Dialog}, + **self.db, + ) + + @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required") + @pytest.mark.skipif( + three_attempts_found, reason="three_attempts plugin already installed" + ) def testDialogAuthThreeAttemptsQuestionsInstallPlugin(self): # needs plugin. lets install it. - cur = self.connections[0].cursor() + cur = self.connect().cursor() try: cur.execute("install plugin three_attempts soname 'dialog_examples.so'") TestAuthentication.three_attempts_found = True self.realTestDialogAuthThreeAttempts() except pymysql.err.InternalError: - raise unittest2.SkipTest('we couldn\'t install the three_attempts plugin') + pytest.skip("we couldn't install the three_attempts plugin") finally: if TestAuthentication.three_attempts_found: cur.execute("uninstall plugin three_attempts") - @unittest2.skipUnless(socket_auth, "connection to unix_socket required") - @unittest2.skipUnless(three_attempts_found, "no three attempts plugin") + @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required") + @pytest.mark.skipif(not three_attempts_found, reason="no three attempts plugin") def testDialogAuthThreeAttempts(self): self.realTestDialogAuthThreeAttempts() def realTestDialogAuthThreeAttempts(self): - TestAuthentication.Dialog.m = {b'Password, please:': b'stillnotverysecret'} - TestAuthentication.Dialog.fail=True # fail just once. We've got three attempts after all - with TempUser(self.connections[0].cursor(), 'pymysql_3a@localhost', - self.databases[0]['db'], 'three_attempts', 'stillnotverysecret') as u: - pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db) - pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.DialogHandler}, **self.db) + TestAuthentication.Dialog.m = {b"Password, please:": b"stillnotverysecret"} + TestAuthentication.Dialog.fail = ( + True # fail just once. We've got three attempts after all + ) + with TempUser( + self.connect().cursor(), + "pymysql_3a@localhost", + self.databases[0]["database"], + "three_attempts", + "stillnotverysecret", + ): + pymysql.connect( + user="pymysql_3a", + auth_plugin_map={b"dialog": TestAuthentication.Dialog}, + **self.db, + ) + pymysql.connect( + user="pymysql_3a", + auth_plugin_map={b"dialog": TestAuthentication.DialogHandler}, + **self.db, + ) with self.assertRaises(pymysql.err.OperationalError): - pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': object}, **self.db) + pymysql.connect( + user="pymysql_3a", auth_plugin_map={b"dialog": object}, **self.db + ) with self.assertRaises(pymysql.err.OperationalError): - pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.DefectiveHandler}, **self.db) + pymysql.connect( + user="pymysql_3a", + auth_plugin_map={b"dialog": TestAuthentication.DefectiveHandler}, + **self.db, + ) with self.assertRaises(pymysql.err.OperationalError): - pymysql.connect(user='pymysql_3a', auth_plugin_map={b'notdialogplugin': TestAuthentication.Dialog}, **self.db) - TestAuthentication.Dialog.m = {b'Password, please:': b'I do not know'} + pymysql.connect( + user="pymysql_3a", + auth_plugin_map={b"notdialogplugin": TestAuthentication.Dialog}, + **self.db, + ) + TestAuthentication.Dialog.m = {b"Password, please:": b"I do not know"} with self.assertRaises(pymysql.err.OperationalError): - pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db) - TestAuthentication.Dialog.m = {b'Password, please:': None} + pymysql.connect( + user="pymysql_3a", + auth_plugin_map={b"dialog": TestAuthentication.Dialog}, + **self.db, + ) + TestAuthentication.Dialog.m = {b"Password, please:": None} with self.assertRaises(pymysql.err.OperationalError): - pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db) - - @unittest2.skipUnless(socket_auth, "connection to unix_socket required") - @unittest2.skipIf(pam_found, "pam plugin already installed") - @unittest2.skipIf(os.environ.get('PASSWORD') is None, "PASSWORD env var required") - @unittest2.skipIf(os.environ.get('PAMSERVICE') is None, "PAMSERVICE env var required") + pymysql.connect( + user="pymysql_3a", + auth_plugin_map={b"dialog": TestAuthentication.Dialog}, + **self.db, + ) + + @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required") + @pytest.mark.skipif(pam_found, reason="pam plugin already installed") + @pytest.mark.skipif( + os.environ.get("PASSWORD") is None, reason="PASSWORD env var required" + ) + @pytest.mark.skipif( + os.environ.get("PAMSERVICE") is None, reason="PAMSERVICE env var required" + ) def testPamAuthInstallPlugin(self): # needs plugin. lets install it. - cur = self.connections[0].cursor() + cur = self.connect().cursor() try: cur.execute("install plugin pam soname 'auth_pam.so'") TestAuthentication.pam_found = True self.realTestPamAuth() except pymysql.err.InternalError: - raise unittest2.SkipTest('we couldn\'t install the auth_pam plugin') + pytest.skip("we couldn't install the auth_pam plugin") finally: if TestAuthentication.pam_found: cur.execute("uninstall plugin pam") - - @unittest2.skipUnless(socket_auth, "connection to unix_socket required") - @unittest2.skipUnless(pam_found, "no pam plugin") - @unittest2.skipIf(os.environ.get('PASSWORD') is None, "PASSWORD env var required") - @unittest2.skipIf(os.environ.get('PAMSERVICE') is None, "PAMSERVICE env var required") + @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required") + @pytest.mark.skipif(not pam_found, reason="no pam plugin") + @pytest.mark.skipif( + os.environ.get("PASSWORD") is None, reason="PASSWORD env var required" + ) + @pytest.mark.skipif( + os.environ.get("PAMSERVICE") is None, reason="PAMSERVICE env var required" + ) def testPamAuth(self): self.realTestPamAuth() def realTestPamAuth(self): db = self.db.copy() import os - db['password'] = os.environ.get('PASSWORD') - cur = self.connections[0].cursor() + + db["password"] = os.environ.get("PASSWORD") + cur = self.connect().cursor() try: - cur.execute('show grants for ' + TestAuthentication.osuser + '@localhost') + cur.execute("show grants for " + TestAuthentication.osuser + "@localhost") grants = cur.fetchone()[0] - cur.execute('drop user ' + TestAuthentication.osuser + '@localhost') + cur.execute("drop user " + TestAuthentication.osuser + "@localhost") except pymysql.OperationalError as e: # assuming the user doesn't exist which is ok too self.assertEqual(1045, e.args[0]) grants = None - with TempUser(cur, TestAuthentication.osuser + '@localhost', - self.databases[0]['db'], 'pam', os.environ.get('PAMSERVICE')) as u: + with TempUser( + cur, + TestAuthentication.osuser + "@localhost", + self.databases[0]["database"], + "pam", + os.environ.get("PAMSERVICE"), + ): try: - c = pymysql.connect(user=TestAuthentication.osuser, **db) - db['password'] = 'very bad guess at password' + pymysql.connect(user=TestAuthentication.osuser, **db) + db["password"] = "very bad guess at password" with self.assertRaises(pymysql.err.OperationalError): - pymysql.connect(user=TestAuthentication.osuser, - auth_plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler}, - **self.db) + pymysql.connect( + user=TestAuthentication.osuser, + auth_plugin_map={ + b"mysql_cleartext_password": TestAuthentication.DefectiveHandler + }, + **self.db, + ) except pymysql.OperationalError as e: self.assertEqual(1045, e.args[0]) - # we had 'bad guess at password' work with pam. Well at least we get a permission denied here + # we had 'bad guess at password' work with pam. Well at least we get + # a permission denied here with self.assertRaises(pymysql.err.OperationalError): - pymysql.connect(user=TestAuthentication.osuser, - auth_plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler}, - **self.db) + pymysql.connect( + user=TestAuthentication.osuser, + auth_plugin_map={ + b"mysql_cleartext_password": TestAuthentication.DefectiveHandler + }, + **self.db, + ) if grants: # recreate the user cur.execute(grants) - # select old_password("crummy p\tassword"); - #| old_password("crummy p\tassword") | - #| 2a01785203b08770 | - @unittest2.skipUnless(socket_auth, "connection to unix_socket required") - @unittest2.skipUnless(mysql_old_password_found, "no mysql_old_password plugin") - def testMySQLOldPasswordAuth(self): - if self.mysql_server_is(self.connections[0], (5, 7, 0)): - raise unittest2.SkipTest('Old passwords aren\'t supported in 5.7') - # pymysql.err.OperationalError: (1045, "Access denied for user 'old_pass_user'@'localhost' (using password: YES)") - # from login in MySQL-5.6 - if self.mysql_server_is(self.connections[0], (5, 6, 0)): - raise unittest2.SkipTest('Old passwords don\'t authenticate in 5.6') - db = self.db.copy() - db['password'] = "crummy p\tassword" - with self.connections[0] as c: - # deprecated in 5.6 - if sys.version_info[0:2] >= (3,2) and self.mysql_server_is(self.connections[0], (5, 6, 0)): - with self.assertWarns(pymysql.err.Warning) as cm: - c.execute("SELECT OLD_PASSWORD('%s')" % db['password']) - else: - c.execute("SELECT OLD_PASSWORD('%s')" % db['password']) - v = c.fetchone()[0] - self.assertEqual(v, '2a01785203b08770') - # only works in MariaDB and MySQL-5.6 - can't separate out by version - #if self.mysql_server_is(self.connections[0], (5, 5, 0)): - # with TempUser(c, 'old_pass_user@localhost', - # self.databases[0]['db'], 'mysql_old_password', '2a01785203b08770') as u: - # cur = pymysql.connect(user='old_pass_user', **db).cursor() - # cur.execute("SELECT VERSION()") - c.execute("SELECT @@secure_auth") - secure_auth_setting = c.fetchone()[0] - c.execute('set old_passwords=1') - # pymysql.err.Warning: 'pre-4.1 password hash' is deprecated and will be removed in a future release. Please use post-4.1 password hash instead - if sys.version_info[0:2] >= (3,2) and self.mysql_server_is(self.connections[0], (5, 6, 0)): - with self.assertWarns(pymysql.err.Warning) as cm: - c.execute('set global secure_auth=0') - else: - c.execute('set global secure_auth=0') - with TempUser(c, 'old_pass_user@localhost', - self.databases[0]['db'], password=db['password']) as u: - cur = pymysql.connect(user='old_pass_user', **db).cursor() - cur.execute("SELECT VERSION()") - c.execute('set global secure_auth=%r' % secure_auth_setting) - - @unittest2.skipUnless(socket_auth, "connection to unix_socket required") - @unittest2.skipUnless(sha256_password_found, "no sha256 password authentication plugin found") + @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required") + @pytest.mark.skipif( + not sha256_password_found, + reason="no sha256 password authentication plugin found", + ) def testAuthSHA256(self): - c = self.connections[0].cursor() - with TempUser(c, 'pymysql_sha256@localhost', - self.databases[0]['db'], 'sha256_password') as u: - if self.mysql_server_is(self.connections[0], (5, 7, 0)): - c.execute("SET PASSWORD FOR 'pymysql_sha256'@'localhost' ='Sh@256Pa33'") - else: - c.execute('SET old_passwords = 2') - c.execute("SET PASSWORD FOR 'pymysql_sha256'@'localhost' = PASSWORD('Sh@256Pa33')") + conn = self.connect() + c = conn.cursor() + with TempUser( + c, + "pymysql_sha256@localhost", + self.databases[0]["database"], + "sha256_password", + ): + c.execute("SET PASSWORD FOR 'pymysql_sha256'@'localhost' ='Sh@256Pa33'") c.execute("FLUSH PRIVILEGES") db = self.db.copy() - db['password'] = "Sh@256Pa33" - # Although SHA256 is supported, need the configuration of public key of the mysql server. Currently will get error by this test. + db["password"] = "Sh@256Pa33" + # Although SHA256 is supported, need the configuration of public key of + # the mysql server. Currently will get error by this test. with self.assertRaises(pymysql.err.OperationalError): - pymysql.connect(user='pymysql_sha256', **db) + pymysql.connect(user="pymysql_sha256", **db) + + @pytest.mark.skipif(not ed25519_found, reason="no ed25519 authention plugin") + def testAuthEd25519(self): + db = self.db.copy() + del db["password"] + conn = self.connect() + c = conn.cursor() + c.execute("select ed25519_password(''), ed25519_password('ed25519_password')") + for r in c: + empty_pass = r[0].decode("ascii") + non_empty_pass = r[1].decode("ascii") + + with TempUser( + c, + "pymysql_ed25519", + self.databases[0]["database"], + "ed25519", + empty_pass, + ): + pymysql.connect(user="pymysql_ed25519", password="", **db) + + with TempUser( + c, + "pymysql_ed25519", + self.databases[0]["database"], + "ed25519", + non_empty_pass, + ): + pymysql.connect(user="pymysql_ed25519", password="ed25519_password", **db) -class TestConnection(base.PyMySQLTestCase): +class TestConnection(base.PyMySQLTestCase): def test_utf8mb4(self): """This test requires MySQL >= 5.5""" arg = self.databases[0].copy() - arg['charset'] = 'utf8mb4' - conn = pymysql.connect(**arg) + arg["charset"] = "utf8mb4" + pymysql.connect(**arg) + + def test_set_character_set(self): + con = self.connect() + cur = con.cursor() + + con.set_character_set("latin1") + cur.execute("SELECT @@character_set_connection") + self.assertEqual(cur.fetchone(), ("latin1",)) + self.assertEqual(con.encoding, "cp1252") + + con.set_character_set("utf8mb4", "utf8mb4_general_ci") + cur.execute("SELECT @@character_set_connection, @@collation_connection") + self.assertEqual(cur.fetchone(), ("utf8mb4", "utf8mb4_general_ci")) + self.assertEqual(con.encoding, "utf8") def test_largedata(self): """Large query and response (>=16MB)""" - cur = self.connections[0].cursor() + cur = self.connect().cursor() cur.execute("SELECT @@max_allowed_packet") - if cur.fetchone()[0] < 16*1024*1024 + 10: + if cur.fetchone()[0] < 16 * 1024 * 1024 + 10: print("Set max_allowed_packet to bigger than 17MB") return - t = 'a' * (16*1024*1024) + t = "a" * (16 * 1024 * 1024) cur.execute("SELECT '" + t + "'") assert cur.fetchone()[0] == t def test_autocommit(self): - con = self.connections[0] + con = self.connect() self.assertFalse(con.get_autocommit()) cur = con.cursor() @@ -400,16 +483,16 @@ def test_autocommit(self): self.assertEqual(cur.fetchone()[0], 0) def test_select_db(self): - con = self.connections[0] - current_db = self.databases[0]['db'] - other_db = self.databases[1]['db'] + con = self.connect() + current_db = self.databases[0]["database"] + other_db = self.databases[1]["database"] cur = con.cursor() - cur.execute('SELECT database()') + cur.execute("SELECT database()") self.assertEqual(cur.fetchone()[0], current_db) con.select_db(other_db) - cur.execute('SELECT database()') + cur.execute("SELECT database()") self.assertEqual(cur.fetchone()[0], other_db) def test_connection_gone_away(self): @@ -423,49 +506,31 @@ def test_connection_gone_away(self): time.sleep(2) with self.assertRaises(pymysql.OperationalError) as cm: cur.execute("SELECT 1+1") - # error occures while reading, not writing because of socket buffer. - #self.assertEqual(cm.exception.args[0], 2006) + # error occurs while reading, not writing because of socket buffer. + # self.assertEqual(cm.exception.args[0], 2006) self.assertIn(cm.exception.args[0], (2006, 2013)) def test_init_command(self): conn = self.connect( init_command='SELECT "bar"; SELECT "baz"', - client_flag=CLIENT.MULTI_STATEMENTS) + client_flag=CLIENT.MULTI_STATEMENTS, + ) c = conn.cursor() c.execute('select "foobar";') - self.assertEqual(('foobar',), c.fetchone()) + self.assertEqual(("foobar",), c.fetchone()) conn.close() with self.assertRaises(pymysql.err.Error): conn.ping(reconnect=False) def test_read_default_group(self): conn = self.connect( - read_default_group='client', + read_default_group="client", ) self.assertTrue(conn.open) - def test_context(self): - with self.assertRaises(ValueError): - c = self.connect() - with c as cur: - cur.execute('create table test ( a int ) ENGINE=InnoDB') - c.begin() - cur.execute('insert into test values ((1))') - raise ValueError('pseudo abort') - c.commit() - c = self.connect() - with c as cur: - cur.execute('select count(*) from test') - self.assertEqual(0, cur.fetchone()[0]) - cur.execute('insert into test values ((1))') - with c as cur: - cur.execute('select count(*) from test') - self.assertEqual(1,cur.fetchone()[0]) - cur.execute('drop table test') - def test_set_charset(self): c = self.connect() - c.set_charset('utf8mb4') + c.set_charset("utf8mb4") # TODO validate setting here def test_defer_connect(self): @@ -474,12 +539,13 @@ def test_defer_connect(self): d = self.databases[0].copy() try: sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - sock.connect(d['unix_socket']) + sock.connect(d["unix_socket"]) except KeyError: sock.close() sock = socket.create_connection( - (d.get('host', 'localhost'), d.get('port', 3306))) - for k in ['unix_socket', 'host', 'port']: + (d.get("host", "localhost"), d.get("port", 3306)) + ) + for k in ["unix_socket", "host", "port"]: try: del d[k] except KeyError: @@ -491,9 +557,250 @@ def test_defer_connect(self): c.close() sock.close() + def test_ssl_connect(self): + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect"), mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: + pymysql.connect( + ssl={ + "ca": "ca", + "cert": "cert", + "key": "key", + "cipher": "cipher", + }, + ) + assert create_default_context.called + assert dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED + dummy_ssl_context.load_cert_chain.assert_called_with( + "cert", + keyfile="key", + password=None, + ) + dummy_ssl_context.set_ciphers.assert_called_with("cipher") + + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect"), mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: + pymysql.connect( + ssl={ + "ca": "ca", + "cert": "cert", + "key": "key", + }, + ) + assert create_default_context.called + assert dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED + dummy_ssl_context.load_cert_chain.assert_called_with( + "cert", + keyfile="key", + password=None, + ) + dummy_ssl_context.set_ciphers.assert_not_called + + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect"), mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: + pymysql.connect( + ssl={ + "ca": "ca", + "cert": "cert", + "key": "key", + "password": "password", + }, + ) + assert create_default_context.called + assert dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED + dummy_ssl_context.load_cert_chain.assert_called_with( + "cert", + keyfile="key", + password="password", + ) + dummy_ssl_context.set_ciphers.assert_not_called + + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect"), mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: + pymysql.connect( + ssl_ca="ca", + ) + assert create_default_context.called + assert not dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_NONE + dummy_ssl_context.load_cert_chain.assert_not_called + dummy_ssl_context.set_ciphers.assert_not_called + + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect"), mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: + pymysql.connect( + ssl_ca="ca", + ssl_cert="cert", + ssl_key="key", + ) + assert create_default_context.called + assert not dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_NONE + dummy_ssl_context.load_cert_chain.assert_called_with( + "cert", + keyfile="key", + password=None, + ) + dummy_ssl_context.set_ciphers.assert_not_called + + for ssl_verify_cert in (True, "1", "yes", "true"): + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect"), mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: + pymysql.connect( + ssl_cert="cert", + ssl_key="key", + ssl_verify_cert=ssl_verify_cert, + ) + assert create_default_context.called + assert not dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED + dummy_ssl_context.load_cert_chain.assert_called_with( + "cert", + keyfile="key", + password=None, + ) + dummy_ssl_context.set_ciphers.assert_not_called + + for ssl_verify_cert in (None, False, "0", "no", "false"): + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect"), mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: + pymysql.connect( + ssl_cert="cert", + ssl_key="key", + ssl_verify_cert=ssl_verify_cert, + ) + assert create_default_context.called + assert not dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_NONE + dummy_ssl_context.load_cert_chain.assert_called_with( + "cert", + keyfile="key", + password=None, + ) + dummy_ssl_context.set_ciphers.assert_not_called + + for ssl_ca in ("ca", None): + for ssl_verify_cert in ("foo", "bar", ""): + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect"), mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: + pymysql.connect( + ssl_ca=ssl_ca, + ssl_cert="cert", + ssl_key="key", + ssl_verify_cert=ssl_verify_cert, + ) + assert create_default_context.called + assert not dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ( + ssl.CERT_REQUIRED if ssl_ca is not None else ssl.CERT_NONE + ), (ssl_ca, ssl_verify_cert) + dummy_ssl_context.load_cert_chain.assert_called_with( + "cert", + keyfile="key", + password=None, + ) + dummy_ssl_context.set_ciphers.assert_not_called + + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect"), mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: + pymysql.connect( + ssl_ca="ca", + ssl_cert="cert", + ssl_key="key", + ssl_verify_identity=True, + ) + assert create_default_context.called + assert dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_NONE + dummy_ssl_context.load_cert_chain.assert_called_with( + "cert", + keyfile="key", + password=None, + ) + dummy_ssl_context.set_ciphers.assert_not_called + + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect"), mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: + pymysql.connect( + ssl_ca="ca", + ssl_cert="cert", + ssl_key="key", + ssl_key_password="password", + ssl_verify_identity=True, + ) + assert create_default_context.called + assert dummy_ssl_context.check_hostname + assert dummy_ssl_context.verify_mode == ssl.CERT_NONE + dummy_ssl_context.load_cert_chain.assert_called_with( + "cert", + keyfile="key", + password="password", + ) + dummy_ssl_context.set_ciphers.assert_not_called + + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect"), mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: + pymysql.connect( + ssl_disabled=True, + ssl={ + "ca": "ca", + "cert": "cert", + "key": "key", + }, + ) + assert not create_default_context.called + + dummy_ssl_context = mock.Mock(options=0) + with mock.patch("pymysql.connections.Connection.connect"), mock.patch( + "pymysql.connections.ssl.create_default_context", + new=mock.Mock(return_value=dummy_ssl_context), + ) as create_default_context: + pymysql.connect( + ssl_disabled=True, + ssl_ca="ca", + ssl_cert="cert", + ssl_key="key", + ) + assert not create_default_context.called + # A custom type and function to escape it -class Foo(object): +class Foo: value = "bar" @@ -503,7 +810,7 @@ def escape_foo(x, d): class TestEscape(base.PyMySQLTestCase): def test_escape_string(self): - con = self.connections[0] + con = self.connect() cur = con.cursor() self.assertEqual(con.escape("foo'bar"), "'foo\\'bar'") @@ -516,46 +823,43 @@ def test_escape_string(self): self.assertEqual(con.escape("foo'bar"), "'foo''bar'") def test_escape_builtin_encoders(self): - con = self.connections[0] - cur = con.cursor() + con = self.connect() val = datetime.datetime(2012, 3, 4, 5, 6) self.assertEqual(con.escape(val, con.encoders), "'2012-03-04 05:06:00'") def test_escape_custom_object(self): - con = self.connections[0] - cur = con.cursor() + con = self.connect() mapping = {Foo: escape_foo} self.assertEqual(con.escape(Foo(), mapping), "bar") def test_escape_fallback_encoder(self): - con = self.connections[0] - cur = con.cursor() + con = self.connect() class Custom(str): pass - mapping = {text_type: pymysql.escape_string} - self.assertEqual(con.escape(Custom('foobar'), mapping), "'foobar'") + mapping = {str: pymysql.converters.escape_string} + self.assertEqual(con.escape(Custom("foobar"), mapping), "'foobar'") def test_escape_no_default(self): - con = self.connections[0] - cur = con.cursor() + con = self.connect() self.assertRaises(TypeError, con.escape, 42, {}) - def test_escape_dict_value(self): - con = self.connections[0] - cur = con.cursor() + def test_escape_dict_raise_typeerror(self): + """con.escape(dict) should raise TypeError""" + con = self.connect() mapping = con.encoders.copy() mapping[Foo] = escape_foo - self.assertEqual(con.escape({'foo': Foo()}, mapping), {'foo': "bar"}) + #self.assertEqual(con.escape({"foo": Foo()}, mapping), {"foo": "bar"}) + with self.assertRaises(TypeError): + con.escape({"foo": Foo()}) def test_escape_list_item(self): - con = self.connections[0] - cur = con.cursor() + con = self.connect() mapping = con.encoders.copy() mapping[Foo] = escape_foo @@ -564,7 +868,8 @@ def test_escape_list_item(self): def test_previous_cursor_not_closed(self): con = self.connect( init_command='SELECT "bar"; SELECT "baz"', - client_flag=CLIENT.MULTI_STATEMENTS) + client_flag=CLIENT.MULTI_STATEMENTS, + ) cur1 = con.cursor() cur1.execute("SELECT 1; SELECT 2") cur2 = con.cursor() diff --git a/pymysql/tests/test_converters.py b/pymysql/tests/test_converters.py index b7b5a9846..b36ee4b39 100644 --- a/pymysql/tests/test_converters.py +++ b/pymysql/tests/test_converters.py @@ -1,7 +1,5 @@ import datetime from unittest import TestCase - -from pymysql._compat import PY2 from pymysql import converters @@ -9,41 +7,30 @@ class TestConverter(TestCase): - def test_escape_string(self): - self.assertEqual( - converters.escape_string(u"foo\nbar"), - u"foo\\nbar" - ) - - if PY2: - def test_escape_string_bytes(self): - self.assertEqual( - converters.escape_string(b"foo\nbar"), - b"foo\\nbar" - ) + self.assertEqual(converters.escape_string("foo\nbar"), "foo\\nbar") def test_convert_datetime(self): expected = datetime.datetime(2007, 2, 24, 23, 6, 20) - dt = converters.convert_datetime('2007-02-24 23:06:20') + dt = converters.convert_datetime("2007-02-24 23:06:20") self.assertEqual(dt, expected) def test_convert_datetime_with_fsp(self): expected = datetime.datetime(2007, 2, 24, 23, 6, 20, 511581) - dt = converters.convert_datetime('2007-02-24 23:06:20.511581') + dt = converters.convert_datetime("2007-02-24 23:06:20.511581") self.assertEqual(dt, expected) def _test_convert_timedelta(self, with_negate=False, with_fsp=False): - d = {'hours': 789, 'minutes': 12, 'seconds': 34} - s = '%(hours)s:%(minutes)s:%(seconds)s' % d + d = {"hours": 789, "minutes": 12, "seconds": 34} + s = "%(hours)s:%(minutes)s:%(seconds)s" % d if with_fsp: - d['microseconds'] = 511581 - s += '.%(microseconds)s' % d + d["microseconds"] = 511581 + s += ".%(microseconds)s" % d expected = datetime.timedelta(**d) if with_negate: expected = -expected - s = '-' + s + s = "-" + s tdelta = converters.convert_timedelta(s) self.assertEqual(tdelta, expected) @@ -58,10 +45,10 @@ def test_convert_timedelta_with_fsp(self): def test_convert_time(self): expected = datetime.time(23, 6, 20) - time_obj = converters.convert_time('23:06:20') + time_obj = converters.convert_time("23:06:20") self.assertEqual(time_obj, expected) def test_convert_time_with_fsp(self): expected = datetime.time(23, 6, 20, 511581) - time_obj = converters.convert_time('23:06:20.511581') + time_obj = converters.convert_time("23:06:20.511581") self.assertEqual(time_obj, expected) diff --git a/pymysql/tests/test_cursor.py b/pymysql/tests/test_cursor.py index add047550..2e267fb6a 100644 --- a/pymysql/tests/test_cursor.py +++ b/pymysql/tests/test_cursor.py @@ -1,25 +1,37 @@ -import warnings - +from pymysql.constants import ER from pymysql.tests import base import pymysql.cursors +import pytest + + class CursorTest(base.PyMySQLTestCase): def setUp(self): - super(CursorTest, self).setUp() + super().setUp() - conn = self.connections[0] + conn = self.connect() self.safe_create_table( conn, - "test", "create table test (data varchar(10))", + "test", + "create table test (data varchar(10))", ) cursor = conn.cursor() cursor.execute( - "insert into test (data) values " - "('row1'), ('row2'), ('row3'), ('row4'), ('row5')") + "insert into test (data) values ('row1'), ('row2'), ('row3'), ('row4'), ('row5')" + ) + conn.commit() cursor.close() self.test_connection = pymysql.connect(**self.databases[0]) self.addCleanup(self.test_connection.close) + def test_cursor_is_iterator(self): + """Test that the cursor is an iterator""" + conn = self.test_connection + cursor = conn.cursor() + cursor.execute("select * from test") + self.assertEqual(cursor.__iter__(), cursor) + self.assertEqual(cursor.__next__(), ("row1",)) + def test_cleanup_rows_unbuffered(self): conn = self.test_connection cursor = conn.cursor(pymysql.cursors.SSCursor) @@ -30,7 +42,6 @@ def test_cleanup_rows_unbuffered(self): break del cursor - self.safe_gc_collect() c2 = conn.cursor() @@ -48,61 +59,164 @@ def test_cleanup_rows_buffered(self): break del cursor - self.safe_gc_collect() c2 = conn.cursor() - c2.execute("select 1") - self.assertEqual( - c2.fetchone(), (1,) - ) + self.assertEqual(c2.fetchone(), (1,)) self.assertIsNone(c2.fetchone()) def test_executemany(self): conn = self.test_connection cursor = conn.cursor(pymysql.cursors.Cursor) - m = pymysql.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%s, %s)") - self.assertIsNotNone(m, 'error parse %s') - self.assertEqual(m.group(3), '', 'group 3 not blank, bug in RE_INSERT_VALUES?') + m = pymysql.cursors.RE_INSERT_VALUES.match( + "INSERT INTO TEST (ID, NAME) VALUES (%s, %s)" + ) + self.assertIsNotNone(m, "error parse %s") + self.assertEqual(m.group(3), "", "group 3 not blank, bug in RE_INSERT_VALUES?") - m = pymysql.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id)s, %(name)s)") - self.assertIsNotNone(m, 'error parse %(name)s') - self.assertEqual(m.group(3), '', 'group 3 not blank, bug in RE_INSERT_VALUES?') + m = pymysql.cursors.RE_INSERT_VALUES.match( + "INSERT INTO TEST (ID, NAME) VALUES (%(id)s, %(name)s)" + ) + self.assertIsNotNone(m, "error parse %(name)s") + self.assertEqual(m.group(3), "", "group 3 not blank, bug in RE_INSERT_VALUES?") - m = pymysql.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s)") - self.assertIsNotNone(m, 'error parse %(id_name)s') - self.assertEqual(m.group(3), '', 'group 3 not blank, bug in RE_INSERT_VALUES?') + m = pymysql.cursors.RE_INSERT_VALUES.match( + "INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s)" + ) + self.assertIsNotNone(m, "error parse %(id_name)s") + self.assertEqual(m.group(3), "", "group 3 not blank, bug in RE_INSERT_VALUES?") - m = pymysql.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s) ON duplicate update") - self.assertIsNotNone(m, 'error parse %(id_name)s') - self.assertEqual(m.group(3), ' ON duplicate update', 'group 3 not ON duplicate update, bug in RE_INSERT_VALUES?') + m = pymysql.cursors.RE_INSERT_VALUES.match( + "INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s) ON duplicate update" + ) + self.assertIsNotNone(m, "error parse %(id_name)s") + self.assertEqual( + m.group(3), + " ON duplicate update", + "group 3 not ON duplicate update, bug in RE_INSERT_VALUES?", + ) # https://github.com/PyMySQL/PyMySQL/pull/597 - m = pymysql.cursors.RE_INSERT_VALUES.match("INSERT INTO bloup(foo, bar)VALUES(%s, %s)") + m = pymysql.cursors.RE_INSERT_VALUES.match( + "INSERT INTO bloup(foo, bar)VALUES(%s, %s)" + ) assert m is not None - # cursor._executed must bee "insert into test (data) values (0),(1),(2),(3),(4),(5),(6),(7),(8),(9)" + # cursor._executed must bee "insert into test (data) + # values (0),(1),(2),(3),(4),(5),(6),(7),(8),(9)" # list args data = range(10) cursor.executemany("insert into test (data) values (%s)", data) - self.assertTrue(cursor._executed.endswith(b",(7),(8),(9)"), 'execute many with %s not in one query') + self.assertTrue( + cursor._executed.endswith(b",(7),(8),(9)"), + "execute many with %s not in one query", + ) # dict args - data_dict = [{'data': i} for i in range(10)] + data_dict = [{"data": i} for i in range(10)] cursor.executemany("insert into test (data) values (%(data)s)", data_dict) - self.assertTrue(cursor._executed.endswith(b",(7),(8),(9)"), 'execute many with %(data)s not in one query') + self.assertTrue( + cursor._executed.endswith(b",(7),(8),(9)"), + "execute many with %(data)s not in one query", + ) # %% in column set - cursor.execute("""\ + cursor.execute( + """\ CREATE TABLE percent_test ( `A%` INTEGER, - `B%` INTEGER)""") + `B%` INTEGER)""" + ) try: q = "INSERT INTO percent_test (`A%%`, `B%%`) VALUES (%s, %s)" self.assertIsNotNone(pymysql.cursors.RE_INSERT_VALUES.match(q)) cursor.executemany(q, [(3, 4), (5, 6)]) - self.assertTrue(cursor._executed.endswith(b"(3, 4),(5, 6)"), "executemany with %% not in one query") + self.assertTrue( + cursor._executed.endswith(b"(3, 4),(5, 6)"), + "executemany with %% not in one query", + ) finally: cursor.execute("DROP TABLE IF EXISTS percent_test") + + def test_execution_time_limit(self): + # this method is similarly implemented in test_SScursor + + conn = self.test_connection + db_type = self.get_mysql_vendor(conn) + + with conn.cursor(pymysql.cursors.Cursor) as cur: + # MySQL MAX_EXECUTION_TIME takes ms + # MariaDB max_statement_time takes seconds as int/float, introduced in 10.1 + + # this will sleep 0.01 seconds per row + if db_type == "mysql": + sql = ( + "SELECT /*+ MAX_EXECUTION_TIME(2000) */ data, sleep(0.01) FROM test" + ) + else: + sql = "SET STATEMENT max_statement_time=2 FOR SELECT data, sleep(0.01) FROM test" + + cur.execute(sql) + # unlike SSCursor, Cursor returns a tuple of tuples here + self.assertEqual( + cur.fetchall(), + ( + ("row1", 0), + ("row2", 0), + ("row3", 0), + ("row4", 0), + ("row5", 0), + ), + ) + + if db_type == "mysql": + sql = ( + "SELECT /*+ MAX_EXECUTION_TIME(2000) */ data, sleep(0.01) FROM test" + ) + else: + sql = "SET STATEMENT max_statement_time=2 FOR SELECT data, sleep(0.01) FROM test" + cur.execute(sql) + self.assertEqual(cur.fetchone(), ("row1", 0)) + + # this discards the previous unfinished query + cur.execute("SELECT 1") + self.assertEqual(cur.fetchone(), (1,)) + + if db_type == "mysql": + sql = "SELECT /*+ MAX_EXECUTION_TIME(1) */ data, sleep(1) FROM test" + else: + sql = "SET STATEMENT max_statement_time=0.001 FOR SELECT data, sleep(1) FROM test" + with pytest.raises(pymysql.err.OperationalError) as cm: + # in a buffered cursor this should reliably raise an + # OperationalError + cur.execute(sql) + + if db_type == "mysql": + # this constant was only introduced in MySQL 5.7, not sure + # what was returned before, may have been ER_QUERY_INTERRUPTED + self.assertEqual(cm.value.args[0], ER.QUERY_TIMEOUT) + else: + self.assertEqual(cm.value.args[0], ER.STATEMENT_TIMEOUT) + + # connection should still be fine at this point + cur.execute("SELECT 1") + self.assertEqual(cur.fetchone(), (1,)) + + def test_warnings(self): + con = self.connect() + cur = con.cursor() + cur.execute("DROP TABLE IF EXISTS `no_exists_table`") + self.assertEqual(cur.warning_count, 1) + + cur.execute("SHOW WARNINGS") + w = cur.fetchone() + self.assertEqual(w[1], ER.BAD_TABLE_ERROR) + self.assertIn( + "no_exists_table", + w[2], + ) + + cur.execute("SELECT 1") + self.assertEqual(cur.warning_count, 0) diff --git a/pymysql/tests/test_err.py b/pymysql/tests/test_err.py index 3468d1b10..6eb0f987d 100644 --- a/pymysql/tests/test_err.py +++ b/pymysql/tests/test_err.py @@ -1,21 +1,16 @@ -import unittest2 - +import pytest from pymysql import err -__all__ = ["TestRaiseException"] - - -class TestRaiseException(unittest2.TestCase): - - def test_raise_mysql_exception(self): - data = b"\xff\x15\x04Access denied" - with self.assertRaises(err.OperationalError) as cm: - err.raise_mysql_exception(data) - self.assertEqual(cm.exception.args, (1045, 'Access denied')) +def test_raise_mysql_exception(): + data = b"\xff\x15\x04#28000Access denied" + with pytest.raises(err.OperationalError) as cm: + err.raise_mysql_exception(data) + assert cm.type == err.OperationalError + assert cm.value.args == (1045, "Access denied") - def test_raise_mysql_exception_client_protocol_41(self): - data = b"\xff\x15\x04#28000Access denied" - with self.assertRaises(err.OperationalError) as cm: - err.raise_mysql_exception(data) - self.assertEqual(cm.exception.args, (1045, 'Access denied')) + data = b"\xff\x10\x04Too many connections" + with pytest.raises(err.OperationalError) as cm: + err.raise_mysql_exception(data) + assert cm.type == err.OperationalError + assert cm.value.args == (1040, "Too many connections") diff --git a/pymysql/tests/test_issues.py b/pymysql/tests/test_issues.py index cedd09258..f1fe8dd48 100644 --- a/pymysql/tests/test_issues.py +++ b/pymysql/tests/test_issues.py @@ -1,34 +1,29 @@ import datetime import time import warnings -import sys + +import pytest import pymysql -from pymysql import cursors -from pymysql._compat import text_type from pymysql.tests import base -import unittest2 - -try: - import imp - reload = imp.reload -except AttributeError: - pass - __all__ = ["TestOldIssues", "TestNewIssues", "TestGitHubIssues"] + class TestOldIssues(base.PyMySQLTestCase): def test_issue_3(self): - """ undefined methods datetime_or_None, date_or_None """ - conn = self.connections[0] + """undefined methods datetime_or_None, date_or_None""" + conn = self.connect() c = conn.cursor() with warnings.catch_warnings(): warnings.filterwarnings("ignore") c.execute("drop table if exists issue3") c.execute("create table issue3 (d date, t time, dt datetime, ts timestamp)") try: - c.execute("insert into issue3 (d, t, dt, ts) values (%s,%s,%s,%s)", (None, None, None, None)) + c.execute( + "insert into issue3 (d, t, dt, ts) values (%s,%s,%s,%s)", + (None, None, None, None), + ) c.execute("select d from issue3") self.assertEqual(None, c.fetchone()[0]) c.execute("select t from issue3") @@ -36,13 +31,17 @@ def test_issue_3(self): c.execute("select dt from issue3") self.assertEqual(None, c.fetchone()[0]) c.execute("select ts from issue3") - self.assertIn(type(c.fetchone()[0]), (type(None), datetime.datetime), 'expected Python type None or datetime from SQL timestamp') + self.assertIn( + type(c.fetchone()[0]), + (type(None), datetime.datetime), + "expected Python type None or datetime from SQL timestamp", + ) finally: c.execute("drop table issue3") def test_issue_4(self): - """ can't retrieve TIMESTAMP fields """ - conn = self.connections[0] + """can't retrieve TIMESTAMP fields""" + conn = self.connect() c = conn.cursor() with warnings.catch_warnings(): warnings.filterwarnings("ignore") @@ -56,32 +55,34 @@ def test_issue_4(self): c.execute("drop table issue4") def test_issue_5(self): - """ query on information_schema.tables fails """ - con = self.connections[0] + """query on information_schema.tables fails""" + con = self.connect() cur = con.cursor() cur.execute("select * from information_schema.tables") def test_issue_6(self): - """ exception: TypeError: ord() expected a character, but string of length 0 found """ + """exception: TypeError: ord() expected a character, but string of length 0 found""" # ToDo: this test requires access to db 'mysql'. kwargs = self.databases[0].copy() - kwargs['db'] = "mysql" + kwargs["database"] = "mysql" conn = pymysql.connect(**kwargs) c = conn.cursor() c.execute("select * from user") conn.close() def test_issue_8(self): - """ Primary Key and Index error when selecting data """ - conn = self.connections[0] + """Primary Key and Index error when selecting data""" + conn = self.connect() c = conn.cursor() with warnings.catch_warnings(): warnings.filterwarnings("ignore") c.execute("drop table if exists test") - c.execute("""CREATE TABLE `test` (`station` int(10) NOT NULL DEFAULT '0', `dh` -datetime NOT NULL DEFAULT '2015-01-01 00:00:00', `echeance` int(1) NOT NULL + c.execute( + """CREATE TABLE `test` (`station` int NOT NULL DEFAULT '0', `dh` +datetime NOT NULL DEFAULT '2015-01-01 00:00:00', `echeance` int NOT NULL DEFAULT '0', `me` double DEFAULT NULL, `mo` double DEFAULT NULL, PRIMARY -KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""") +KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""" + ) try: self.assertEqual(0, c.execute("SELECT * FROM test")) c.execute("ALTER TABLE `test` ADD INDEX `idx_station` (`station`)") @@ -89,16 +90,9 @@ def test_issue_8(self): finally: c.execute("drop table test") - def test_issue_9(self): - """ sets DeprecationWarning in Python 2.6 """ - try: - reload(pymysql) - except DeprecationWarning: - self.fail() - def test_issue_13(self): - """ can't handle large result fields """ - conn = self.connections[0] + """can't handle large result fields""" + conn = self.connect() cur = conn.cursor() with warnings.catch_warnings(): warnings.filterwarnings("ignore") @@ -106,7 +100,7 @@ def test_issue_13(self): try: cur.execute("create table issue13 (t text)") # ticket says 18k - size = 18*1024 + size = 18 * 1024 cur.execute("insert into issue13 (t) values (%s)", ("x" * size,)) cur.execute("select t from issue13") # use assertTrue so that obscenely huge error messages don't print @@ -116,41 +110,47 @@ def test_issue_13(self): cur.execute("drop table issue13") def test_issue_15(self): - """ query should be expanded before perform character encoding """ - conn = self.connections[0] + """query should be expanded before perform character encoding""" + conn = self.connect() c = conn.cursor() with warnings.catch_warnings(): warnings.filterwarnings("ignore") c.execute("drop table if exists issue15") c.execute("create table issue15 (t varchar(32))") try: - c.execute("insert into issue15 (t) values (%s)", (u'\xe4\xf6\xfc',)) + c.execute("insert into issue15 (t) values (%s)", ("\xe4\xf6\xfc",)) c.execute("select t from issue15") - self.assertEqual(u'\xe4\xf6\xfc', c.fetchone()[0]) + self.assertEqual("\xe4\xf6\xfc", c.fetchone()[0]) finally: c.execute("drop table issue15") def test_issue_16(self): - """ Patch for string and tuple escaping """ - conn = self.connections[0] + """Patch for string and tuple escaping""" + conn = self.connect() c = conn.cursor() with warnings.catch_warnings(): warnings.filterwarnings("ignore") c.execute("drop table if exists issue16") - c.execute("create table issue16 (name varchar(32) primary key, email varchar(32))") + c.execute( + "create table issue16 (name varchar(32) primary key, email varchar(32))" + ) try: - c.execute("insert into issue16 (name, email) values ('pete', 'floydophone')") + c.execute( + "insert into issue16 (name, email) values ('pete', 'floydophone')" + ) c.execute("select email from issue16 where name=%s", ("pete",)) self.assertEqual("floydophone", c.fetchone()[0]) finally: c.execute("drop table issue16") - @unittest2.skip("test_issue_17() requires a custom, legacy MySQL configuration and will not be run.") + @pytest.mark.skip( + "test_issue_17() requires a custom, legacy MySQL configuration and will not be run." + ) def test_issue_17(self): - """could not connect mysql use passwod""" - conn = self.connections[0] + """could not connect mysql use password""" + conn = self.connect() host = self.databases[0]["host"] - db = self.databases[0]["db"] + db = self.databases[0]["database"] c = conn.cursor() # grant access to a table to a user with a password @@ -160,7 +160,10 @@ def test_issue_17(self): c.execute("drop table if exists issue17") c.execute("create table issue17 (x varchar(32) primary key)") c.execute("insert into issue17 (x) values ('hello, world!')") - c.execute("grant all privileges on %s.issue17 to 'issue17user'@'%%' identified by '1234'" % db) + c.execute( + "grant all privileges on %s.issue17 to 'issue17user'@'%%' identified by '1234'" + % db + ) conn.commit() conn2 = pymysql.connect(host=host, user="issue17user", passwd="1234", db=db) @@ -170,6 +173,7 @@ def test_issue_17(self): finally: c.execute("drop table issue17") + class TestNewIssues(base.PyMySQLTestCase): def test_issue_34(self): try: @@ -182,16 +186,17 @@ def test_issue_34(self): def test_issue_33(self): conn = pymysql.connect(charset="utf8", **self.databases[0]) - self.safe_create_table(conn, u'hei\xdfe', - u'create table hei\xdfe (name varchar(32))') + self.safe_create_table( + conn, "hei\xdfe", "create table hei\xdfe (name varchar(32))" + ) c = conn.cursor() - c.execute(u"insert into hei\xdfe (name) values ('Pi\xdfata')") - c.execute(u"select name from hei\xdfe") - self.assertEqual(u"Pi\xdfata", c.fetchone()[0]) + c.execute("insert into hei\xdfe (name) values ('Pi\xdfata')") + c.execute("select name from hei\xdfe") + self.assertEqual("Pi\xdfata", c.fetchone()[0]) - @unittest2.skip("This test requires manual intervention") + @pytest.mark.skip("This test requires manual intervention") def test_issue_35(self): - conn = self.connections[0] + conn = self.connect() c = conn.cursor() print("sudo killall -9 mysqld within the next 10 seconds") try: @@ -237,7 +242,7 @@ def test_issue_36(self): del self.connections[1] def test_issue_37(self): - conn = self.connections[0] + conn = self.connect() c = conn.cursor() self.assertEqual(1, c.execute("SELECT @foo")) self.assertEqual((None,), c.fetchone()) @@ -245,9 +250,9 @@ def test_issue_37(self): c.execute("set @foo = 'bar'") def test_issue_38(self): - conn = self.connections[0] + conn = self.connect() c = conn.cursor() - datum = "a" * 1024 * 1023 # reduced size for most default mysql installs + datum = "a" * 1024 * 1023 # reduced size for most default mysql installs try: with warnings.catch_warnings(): @@ -259,13 +264,13 @@ def test_issue_38(self): c.execute("drop table issue38") def disabled_test_issue_54(self): - conn = self.connections[0] + conn = self.connect() c = conn.cursor() with warnings.catch_warnings(): warnings.filterwarnings("ignore") c.execute("drop table if exists issue54") big_sql = "select * from issue54 where " - big_sql += " and ".join("%d=%d" % (i,i) for i in range(0, 100000)) + big_sql += " and ".join("%d=%d" % (i, i) for i in range(0, 100000)) try: c.execute("create table issue54 (id integer primary key)") @@ -275,17 +280,20 @@ def disabled_test_issue_54(self): finally: c.execute("drop table issue54") + class TestGitHubIssues(base.PyMySQLTestCase): def test_issue_66(self): - """ 'Connection' object has no attribute 'insert_id' """ - conn = self.connections[0] + """'Connection' object has no attribute 'insert_id'""" + conn = self.connect() c = conn.cursor() self.assertEqual(0, conn.insert_id()) try: with warnings.catch_warnings(): warnings.filterwarnings("ignore") c.execute("drop table if exists issue66") - c.execute("create table issue66 (id integer primary key auto_increment, x integer)") + c.execute( + "create table issue66 (id integer primary key auto_increment, x integer)" + ) c.execute("insert into issue66 (x) values (1)") c.execute("insert into issue66 (x) values (1)") self.assertEqual(2, conn.insert_id()) @@ -293,8 +301,8 @@ def test_issue_66(self): c.execute("drop table issue66") def test_issue_79(self): - """ Duplicate field overwrites the previous one in the result of DictCursor """ - conn = self.connections[0] + """Duplicate field overwrites the previous one in the result of DictCursor""" + conn = self.connect() c = conn.cursor(pymysql.cursors.DictCursor) with warnings.catch_warnings(): @@ -304,32 +312,34 @@ def test_issue_79(self): c.execute("""CREATE TABLE a (id int, value int)""") c.execute("""CREATE TABLE b (id int, value int)""") - a=(1,11) - b=(1,22) + a = (1, 11) + b = (1, 22) try: c.execute("insert into a values (%s, %s)", a) c.execute("insert into b values (%s, %s)", b) c.execute("SELECT * FROM a inner join b on a.id = b.id") r = c.fetchall()[0] - self.assertEqual(r['id'], 1) - self.assertEqual(r['value'], 11) - self.assertEqual(r['b.value'], 22) + self.assertEqual(r["id"], 1) + self.assertEqual(r["value"], 11) + self.assertEqual(r["b.value"], 22) finally: c.execute("drop table a") c.execute("drop table b") def test_issue_95(self): - """ Leftover trailing OK packet for "CALL my_sp" queries """ - conn = self.connections[0] + """Leftover trailing OK packet for "CALL my_sp" queries""" + conn = self.connect() cur = conn.cursor() with warnings.catch_warnings(): warnings.filterwarnings("ignore") cur.execute("DROP PROCEDURE IF EXISTS `foo`") - cur.execute("""CREATE PROCEDURE `foo` () + cur.execute( + """CREATE PROCEDURE `foo` () BEGIN SELECT 1; - END""") + END""" + ) try: cur.execute("""CALL foo()""") cur.execute("""SELECT 1""") @@ -340,7 +350,7 @@ def test_issue_95(self): cur.execute("DROP PROCEDURE IF EXISTS `foo`") def test_issue_114(self): - """ autocommit is not set after reconnecting with ping() """ + """autocommit is not set after reconnecting with ping()""" conn = pymysql.connect(charset="utf8", **self.databases[0]) conn.autocommit(False) c = conn.cursor() @@ -365,59 +375,62 @@ def test_issue_114(self): conn.close() def test_issue_175(self): - """ The number of fields returned by server is read in wrong way """ - conn = self.connections[0] + """The number of fields returned by server is read in wrong way""" + conn = self.connect() cur = conn.cursor() for length in (200, 300): - columns = ', '.join('c{0} integer'.format(i) for i in range(length)) - sql = 'create table test_field_count ({0})'.format(columns) + columns = ", ".join(f"c{i} integer" for i in range(length)) + sql = f"create table test_field_count ({columns})" try: cur.execute(sql) - cur.execute('select * from test_field_count') + cur.execute("select * from test_field_count") assert len(cur.description) == length finally: with warnings.catch_warnings(): warnings.filterwarnings("ignore") - cur.execute('drop table if exists test_field_count') + cur.execute("drop table if exists test_field_count") def test_issue_321(self): - """ Test iterable as query argument. """ + """Test iterable as query argument.""" conn = pymysql.connect(charset="utf8", **self.databases[0]) self.safe_create_table( - conn, "issue321", - "create table issue321 (value_1 varchar(1), value_2 varchar(1))") + conn, + "issue321", + "create table issue321 (value_1 varchar(1), value_2 varchar(1))", + ) sql_insert = "insert into issue321 (value_1, value_2) values (%s, %s)" - sql_dict_insert = ("insert into issue321 (value_1, value_2) " - "values (%(value_1)s, %(value_2)s)") - sql_select = ("select * from issue321 where " - "value_1 in %s and value_2=%s") + sql_dict_insert = ( + "insert into issue321 (value_1, value_2) values (%(value_1)s, %(value_2)s)" + ) + sql_select = "select * from issue321 where value_1 in %s and value_2=%s" data = [ - [(u"a", ), u"\u0430"], - [[u"b"], u"\u0430"], - {"value_1": [[u"c"]], "value_2": u"\u0430"} + [("a",), "\u0430"], + [["b"], "\u0430"], + {"value_1": [["c"]], "value_2": "\u0430"}, ] cur = conn.cursor() self.assertEqual(cur.execute(sql_insert, data[0]), 1) self.assertEqual(cur.execute(sql_insert, data[1]), 1) self.assertEqual(cur.execute(sql_dict_insert, data[2]), 1) - self.assertEqual( - cur.execute(sql_select, [(u"a", u"b", u"c"), u"\u0430"]), 3) - self.assertEqual(cur.fetchone(), (u"a", u"\u0430")) - self.assertEqual(cur.fetchone(), (u"b", u"\u0430")) - self.assertEqual(cur.fetchone(), (u"c", u"\u0430")) + self.assertEqual(cur.execute(sql_select, [("a", "b", "c"), "\u0430"]), 3) + self.assertEqual(cur.fetchone(), ("a", "\u0430")) + self.assertEqual(cur.fetchone(), ("b", "\u0430")) + self.assertEqual(cur.fetchone(), ("c", "\u0430")) def test_issue_364(self): - """ Test mixed unicode/binary arguments in executemany. """ + """Test mixed unicode/binary arguments in executemany.""" conn = pymysql.connect(charset="utf8mb4", **self.databases[0]) self.safe_create_table( - conn, "issue364", + conn, + "issue364", "create table issue364 (value_1 binary(3), value_2 varchar(3)) " - "engine=InnoDB default charset=utf8mb4") + "engine=InnoDB default charset=utf8mb4", + ) sql = "insert into issue364 (value_1, value_2) values (_binary %s, %s)" - usql = u"insert into issue364 (value_1, value_2) values (_binary %s, %s)" - values = [pymysql.Binary(b"\x00\xff\x00"), u"\xe4\xf6\xfc"] + usql = "insert into issue364 (value_1, value_2) values (_binary %s, %s)" + values = [pymysql.Binary(b"\x00\xff\x00"), "\xe4\xf6\xfc"] # test single insert and select cur = conn.cursor() @@ -438,45 +451,44 @@ def test_issue_364(self): cur.executemany(usql, args=(values, values, values)) def test_issue_363(self): - """ Test binary / geometry types. """ + """Test binary / geometry types.""" conn = pymysql.connect(charset="utf8", **self.databases[0]) self.safe_create_table( - conn, "issue363", + conn, + "issue363", "CREATE TABLE issue363 ( " "id INTEGER PRIMARY KEY, geom LINESTRING NOT NULL /*!80003 SRID 0 */, " "SPATIAL KEY geom (geom)) " - "ENGINE=MyISAM") + "ENGINE=MyISAM", + ) cur = conn.cursor() - # From MySQL 5.7, ST_GeomFromText is added and GeomFromText is deprecated. - if self.mysql_server_is(conn, (5, 7, 0)): - geom_from_text = "ST_GeomFromText" - geom_as_text = "ST_AsText" - geom_as_bin = "ST_AsBinary" - else: - geom_from_text = "GeomFromText" - geom_as_text = "AsText" - geom_as_bin = "AsBinary" - query = ("INSERT INTO issue363 (id, geom) VALUES" - "(1998, %s('LINESTRING(1.1 1.1,2.2 2.2)'))" % geom_from_text) + query = ( + "INSERT INTO issue363 (id, geom) VALUES" + "(1998, ST_GeomFromText('LINESTRING(1.1 1.1,2.2 2.2)'))" + ) cur.execute(query) # select WKT - query = "SELECT %s(geom) FROM issue363" % geom_as_text + query = "SELECT ST_AsText(geom) FROM issue363" cur.execute(query) row = cur.fetchone() - self.assertEqual(row, ("LINESTRING(1.1 1.1,2.2 2.2)", )) + self.assertEqual(row, ("LINESTRING(1.1 1.1,2.2 2.2)",)) # select WKB - query = "SELECT %s(geom) FROM issue363" % geom_as_bin + query = "SELECT ST_AsBinary(geom) FROM issue363" cur.execute(query) row = cur.fetchone() - self.assertEqual(row, - (b"\x01\x02\x00\x00\x00\x02\x00\x00\x00" - b"\x9a\x99\x99\x99\x99\x99\xf1?" - b"\x9a\x99\x99\x99\x99\x99\xf1?" - b"\x9a\x99\x99\x99\x99\x99\x01@" - b"\x9a\x99\x99\x99\x99\x99\x01@", )) + self.assertEqual( + row, + ( + b"\x01\x02\x00\x00\x00\x02\x00\x00\x00" + b"\x9a\x99\x99\x99\x99\x99\xf1?" + b"\x9a\x99\x99\x99\x99\x99\xf1?" + b"\x9a\x99\x99\x99\x99\x99\x01@" + b"\x9a\x99\x99\x99\x99\x99\x01@", + ), + ) # select internal binary cur.execute("SELECT geom FROM issue363") @@ -484,28 +496,3 @@ def test_issue_363(self): # don't assert the exact internal binary value, as it could # vary across implementations self.assertTrue(isinstance(row[0], bytes)) - - def test_issue_491(self): - """ Test warning propagation """ - conn = pymysql.connect(charset="utf8", **self.databases[0]) - - with warnings.catch_warnings(): - # Ignore all warnings other than pymysql generated ones - warnings.simplefilter("ignore") - warnings.simplefilter("error", category=pymysql.Warning) - - # verify for both buffered and unbuffered cursor types - for cursor_class in (cursors.Cursor, cursors.SSCursor): - c = conn.cursor(cursor_class) - try: - c.execute("SELECT CAST('124b' AS SIGNED)") - c.fetchall() - except pymysql.Warning as e: - # Warnings should have errorcode and string message, just like exceptions - self.assertEqual(len(e.args), 2) - self.assertEqual(e.args[0], 1292) - self.assertTrue(isinstance(e.args[1], text_type)) - else: - self.fail("Should raise Warning") - finally: - c.close() diff --git a/pymysql/tests/test_load_local.py b/pymysql/tests/test_load_local.py index 85fd94ea5..509221420 100644 --- a/pymysql/tests/test_load_local.py +++ b/pymysql/tests/test_load_local.py @@ -1,8 +1,8 @@ -from pymysql import cursors, OperationalError, Warning +from pymysql import cursors, OperationalError +from pymysql.constants import ER from pymysql.tests import base import os -import warnings __all__ = ["TestLoadLocal"] @@ -10,15 +10,17 @@ class TestLoadLocal(base.PyMySQLTestCase): def test_no_file(self): """Test load local infile when the file does not exist""" - conn = self.connections[0] + conn = self.connect() c = conn.cursor() c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)") try: self.assertRaises( OperationalError, c.execute, - ("LOAD DATA LOCAL INFILE 'no_data.txt' INTO TABLE " - "test_load_local fields terminated by ','") + ( + "LOAD DATA LOCAL INFILE 'no_data.txt' INTO TABLE " + "test_load_local fields terminated by ','" + ), ) finally: c.execute("DROP TABLE test_load_local") @@ -26,16 +28,16 @@ def test_no_file(self): def test_load_file(self): """Test load local infile with a valid file""" - conn = self.connections[0] + conn = self.connect() c = conn.cursor() c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)") - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), - 'data', - 'load_local_data.txt') + filename = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "data", "load_local_data.txt" + ) try: c.execute( - ("LOAD DATA LOCAL INFILE '{0}' INTO TABLE " + - "test_load_local FIELDS TERMINATED BY ','").format(filename) + f"LOAD DATA LOCAL INFILE '{filename}' INTO TABLE test_load_local" + + " FIELDS TERMINATED BY ','" ) c.execute("SELECT COUNT(*) FROM test_load_local") self.assertEqual(22749, c.fetchone()[0]) @@ -44,16 +46,16 @@ def test_load_file(self): def test_unbuffered_load_file(self): """Test unbuffered load local infile with a valid file""" - conn = self.connections[0] + conn = self.connect() c = conn.cursor(cursors.SSCursor) c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)") - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), - 'data', - 'load_local_data.txt') + filename = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "data", "load_local_data.txt" + ) try: c.execute( - ("LOAD DATA LOCAL INFILE '{0}' INTO TABLE " + - "test_load_local FIELDS TERMINATED BY ','").format(filename) + f"LOAD DATA LOCAL INFILE '{filename}' INTO TABLE test_load_local" + + " FIELDS TERMINATED BY ','" ) c.execute("SELECT COUNT(*) FROM test_load_local") self.assertEqual(22749, c.fetchone()[0]) @@ -66,23 +68,31 @@ def test_unbuffered_load_file(self): def test_load_warnings(self): """Test load local infile produces the appropriate warnings""" - conn = self.connections[0] + conn = self.connect() c = conn.cursor() c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)") - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), - 'data', - 'load_local_warn_data.txt') + filename = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "data", + "load_local_warn_data.txt", + ) try: - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - c.execute( - ("LOAD DATA LOCAL INFILE '{0}' INTO TABLE " + - "test_load_local FIELDS TERMINATED BY ','").format(filename) - ) - self.assertEqual(w[0].category, Warning) - expected_message = "Incorrect integer value" - if expected_message not in str(w[-1].message): - self.fail("%r not in %r" % (expected_message, w[-1].message)) + c.execute( + ( + "LOAD DATA LOCAL INFILE '{0}' INTO TABLE " + + "test_load_local FIELDS TERMINATED BY ','" + ).format(filename) + ) + self.assertEqual(1, c.warning_count) + + c.execute("SHOW WARNINGS") + w = c.fetchone() + + self.assertEqual(ER.TRUNCATED_WRONG_VALUE_FOR_FIELD, w[1]) + self.assertIn( + "incorrect integer value", + w[2].lower(), + ) finally: c.execute("DROP TABLE test_load_local") c.close() @@ -90,4 +100,5 @@ def test_load_warnings(self): if __name__ == "__main__": import unittest + unittest.main() diff --git a/pymysql/tests/test_nextset.py b/pymysql/tests/test_nextset.py index 998441072..a10f8d5b7 100644 --- a/pymysql/tests/test_nextset.py +++ b/pymysql/tests/test_nextset.py @@ -1,17 +1,16 @@ -import unittest2 +import pytest import pymysql -from pymysql import util from pymysql.tests import base from pymysql.constants import CLIENT class TestNextset(base.PyMySQLTestCase): - def test_nextset(self): con = self.connect( init_command='SELECT "bar"; SELECT "baz"', - client_flag=CLIENT.MULTI_STATEMENTS) + client_flag=CLIENT.MULTI_STATEMENTS, + ) cur = con.cursor() cur.execute("SELECT 1; SELECT 2;") self.assertEqual([(1,)], list(cur)) @@ -39,7 +38,7 @@ def test_nextset_error(self): self.assertEqual([(i,)], list(cur)) with self.assertRaises(pymysql.ProgrammingError): cur.nextset() - self.assertEqual((), cur.fetchall()) + self.assertEqual([], cur.fetchall()) def test_ok_and_next(self): cur = self.connect(client_flag=CLIENT.MULTI_STATEMENTS).cursor() @@ -50,7 +49,7 @@ def test_ok_and_next(self): self.assertEqual([(2,)], list(cur)) self.assertFalse(bool(cur.nextset())) - @unittest2.expectedFailure + @pytest.mark.xfail def test_multi_cursor(self): con = self.connect(client_flag=CLIENT.MULTI_STATEMENTS) cur1 = con.cursor() @@ -71,14 +70,14 @@ def test_multi_cursor(self): def test_multi_statement_warnings(self): con = self.connect( init_command='SELECT "bar"; SELECT "baz"', - client_flag=CLIENT.MULTI_STATEMENTS) + client_flag=CLIENT.MULTI_STATEMENTS, + ) cursor = con.cursor() try: - cursor.execute('DROP TABLE IF EXISTS a; ' - 'DROP TABLE IF EXISTS b;') + cursor.execute("DROP TABLE IF EXISTS a; DROP TABLE IF EXISTS b;") except TypeError: self.fail() - #TODO: How about SSCursor and nextset? + # TODO: How about SSCursor and nextset? # It's very hard to implement correctly... diff --git a/pymysql/tests/test_optionfile.py b/pymysql/tests/test_optionfile.py index 3ee519e2b..d13553dda 100644 --- a/pymysql/tests/test_optionfile.py +++ b/pymysql/tests/test_optionfile.py @@ -1,33 +1,24 @@ -from pymysql.optionfile import Parser +from io import StringIO from unittest import TestCase -from pymysql._compat import PY2 - -try: - from cStringIO import StringIO -except ImportError: - from io import StringIO +from pymysql.optionfile import Parser -__all__ = ['TestParser'] +__all__ = ["TestParser"] -_cfg_file = (r""" +_cfg_file = r""" [default] string = foo quoted = "bar" single_quoted = 'foobar' skip-slave-start -""") +""" class TestParser(TestCase): - def test_string(self): parser = Parser() - if PY2: - parser.readfp(StringIO(_cfg_file)) - else: - parser.read_file(StringIO(_cfg_file)) + parser.read_file(StringIO(_cfg_file)) self.assertEqual(parser.get("default", "string"), "foo") self.assertEqual(parser.get("default", "quoted"), "bar") - self.assertEqual(parser.get("default", "single_quoted"), "foobar") + self.assertEqual(parser.get("default", "single-quoted"), "foobar") diff --git a/pymysql/tests/thirdparty/__init__.py b/pymysql/tests/thirdparty/__init__.py index 6d59e1127..d5f053711 100644 --- a/pymysql/tests/thirdparty/__init__.py +++ b/pymysql/tests/thirdparty/__init__.py @@ -1,8 +1,6 @@ from .test_MySQLdb import * if __name__ == "__main__": - try: - import unittest2 as unittest - except ImportError: - import unittest + import unittest + unittest.main() diff --git a/pymysql/tests/thirdparty/test_MySQLdb/__init__.py b/pymysql/tests/thirdparty/test_MySQLdb/__init__.py index e4237c69a..501bfd2db 100644 --- a/pymysql/tests/thirdparty/test_MySQLdb/__init__.py +++ b/pymysql/tests/thirdparty/test_MySQLdb/__init__.py @@ -1,7 +1,6 @@ -from .test_MySQLdb_capabilities import test_MySQLdb as test_capabilities from .test_MySQLdb_nonstandard import * -from .test_MySQLdb_dbapi20 import test_MySQLdb as test_dbapi2 if __name__ == "__main__": import unittest + unittest.main() diff --git a/pymysql/tests/thirdparty/test_MySQLdb/capabilities.py b/pymysql/tests/thirdparty/test_MySQLdb/capabilities.py index bcf9eecbf..bb47cc5f6 100644 --- a/pymysql/tests/thirdparty/test_MySQLdb/capabilities.py +++ b/pymysql/tests/thirdparty/test_MySQLdb/capabilities.py @@ -4,17 +4,11 @@ Adapted from a script by M-A Lemburg. """ -import sys from time import time -try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest -PY2 = sys.version_info[0] == 2 class DatabaseTest(unittest.TestCase): - db_module = None connect_args = () connect_kwargs = dict(use_unicode=True, charset="utf8mb4", binary_prefix=True) @@ -26,11 +20,8 @@ def setUp(self): db = self.db_module.connect(*self.connect_args, **self.connect_kwargs) self.connection = db self.cursor = db.cursor() - self.BLOBText = ''.join([chr(i) for i in range(256)] * 100); - if PY2: - self.BLOBUText = unicode().join(unichr(i) for i in range(16834)) - else: - self.BLOBUText = "".join(chr(i) for i in range(16834)) + self.BLOBText = "".join([chr(i) for i in range(256)] * 100) + self.BLOBUText = "".join(chr(i) for i in range(16834)) data = bytearray(range(256)) * 16 self.BLOBBinary = self.db_module.Binary(data) @@ -39,17 +30,22 @@ def setUp(self): def tearDown(self): if self.leak_test: import gc + del self.cursor orphans = gc.collect() - self.assertFalse(orphans, "%d orphaned objects found after deleting cursor" % orphans) + self.assertFalse( + orphans, "%d orphaned objects found after deleting cursor" % orphans + ) del self.connection orphans = gc.collect() - self.assertFalse(orphans, "%d orphaned objects found after deleting connection" % orphans) + self.assertFalse( + orphans, "%d orphaned objects found after deleting connection" % orphans + ) def table_exists(self, name): try: - self.cursor.execute('select * from %s where 1=0' % name) + self.cursor.execute("select * from %s where 1=0" % name) except Exception: return False else: @@ -61,41 +57,41 @@ def quote_identifier(self, ident): def new_table_name(self): i = id(self.cursor) while True: - name = self.quote_identifier('tb%08x' % i) + name = self.quote_identifier("tb%08x" % i) if not self.table_exists(name): return name i = i + 1 def create_table(self, columndefs): + """ + Create a table using a list of column definitions given in columndefs. - """ Create a table using a list of column definitions given in - columndefs. - - generator must be a function taking arguments (row_number, - col_number) returning a suitable data object for insertion - into the table. - + generator must be a function taking arguments (row_number, + col_number) returning a suitable data object for insertion + into the table. """ self.table = self.new_table_name() - self.cursor.execute('CREATE TABLE %s (%s) %s' % - (self.table, - ',\n'.join(columndefs), - self.create_table_extra)) + self.cursor.execute( + "CREATE TABLE %s (%s) %s" + % (self.table, ",\n".join(columndefs), self.create_table_extra) + ) def check_data_integrity(self, columndefs, generator): # insert self.create_table(columndefs) - insert_statement = ('INSERT INTO %s VALUES (%s)' % - (self.table, - ','.join(['%s'] * len(columndefs)))) - data = [ [ generator(i,j) for j in range(len(columndefs)) ] - for i in range(self.rows) ] + insert_statement = "INSERT INTO %s VALUES (%s)" % ( + self.table, + ",".join(["%s"] * len(columndefs)), + ) + data = [ + [generator(i, j) for j in range(len(columndefs))] for i in range(self.rows) + ] if self.debug: print(data) self.cursor.executemany(insert_statement, data) self.connection.commit() # verify - self.cursor.execute('select * from %s' % self.table) + self.cursor.execute("select * from %s" % self.table) l = self.cursor.fetchall() if self.debug: print(l) @@ -103,62 +99,74 @@ def check_data_integrity(self, columndefs, generator): try: for i in range(self.rows): for j in range(len(columndefs)): - self.assertEqual(l[i][j], generator(i,j)) + self.assertEqual(l[i][j], generator(i, j)) finally: if not self.debug: - self.cursor.execute('drop table %s' % (self.table)) + self.cursor.execute("drop table %s" % (self.table)) def test_transactions(self): - columndefs = ( 'col1 INT', 'col2 VARCHAR(255)') + columndefs = ("col1 INT", "col2 VARCHAR(255)") + def generator(row, col): - if col == 0: return row - else: return ('%i' % (row%10))*255 + if col == 0: + return row + else: + return ("%i" % (row % 10)) * 255 + self.create_table(columndefs) - insert_statement = ('INSERT INTO %s VALUES (%s)' % - (self.table, - ','.join(['%s'] * len(columndefs)))) - data = [ [ generator(i,j) for j in range(len(columndefs)) ] - for i in range(self.rows) ] + insert_statement = "INSERT INTO %s VALUES (%s)" % ( + self.table, + ",".join(["%s"] * len(columndefs)), + ) + data = [ + [generator(i, j) for j in range(len(columndefs))] for i in range(self.rows) + ] self.cursor.executemany(insert_statement, data) # verify self.connection.commit() - self.cursor.execute('select * from %s' % self.table) + self.cursor.execute("select * from %s" % self.table) l = self.cursor.fetchall() self.assertEqual(len(l), self.rows) for i in range(self.rows): for j in range(len(columndefs)): - self.assertEqual(l[i][j], generator(i,j)) - delete_statement = 'delete from %s where col1=%%s' % self.table + self.assertEqual(l[i][j], generator(i, j)) + delete_statement = "delete from %s where col1=%%s" % self.table self.cursor.execute(delete_statement, (0,)) - self.cursor.execute('select col1 from %s where col1=%s' % \ - (self.table, 0)) + self.cursor.execute("select col1 from %s where col1=%s" % (self.table, 0)) l = self.cursor.fetchall() self.assertFalse(l, "DELETE didn't work") self.connection.rollback() - self.cursor.execute('select col1 from %s where col1=%s' % \ - (self.table, 0)) + self.cursor.execute("select col1 from %s where col1=%s" % (self.table, 0)) l = self.cursor.fetchall() self.assertTrue(len(l) == 1, "ROLLBACK didn't work") - self.cursor.execute('drop table %s' % (self.table)) + self.cursor.execute("drop table %s" % (self.table)) def test_truncation(self): - columndefs = ( 'col1 INT', 'col2 VARCHAR(255)') + columndefs = ("col1 INT", "col2 VARCHAR(255)") + def generator(row, col): - if col == 0: return row - else: return ('%i' % (row%10))*((255-self.rows//2)+row) + if col == 0: + return row + else: + return ("%i" % (row % 10)) * ((255 - self.rows // 2) + row) + self.create_table(columndefs) - insert_statement = ('INSERT INTO %s VALUES (%s)' % - (self.table, - ','.join(['%s'] * len(columndefs)))) + insert_statement = "INSERT INTO %s VALUES (%s)" % ( + self.table, + ",".join(["%s"] * len(columndefs)), + ) try: - self.cursor.execute(insert_statement, (0, '0'*256)) + self.cursor.execute(insert_statement, (0, "0" * 256)) except Warning: - if self.debug: print(self.cursor.messages) + if self.debug: + print(self.cursor.messages) except self.connection.DataError: pass else: - self.fail("Over-long column did not generate warnings/exception with single insert") + self.fail( + "Over-long column did not generate warnings/exception with single insert" + ) self.connection.rollback() @@ -166,132 +174,136 @@ def generator(row, col): for i in range(self.rows): data = [] for j in range(len(columndefs)): - data.append(generator(i,j)) - self.cursor.execute(insert_statement,tuple(data)) + data.append(generator(i, j)) + self.cursor.execute(insert_statement, tuple(data)) except Warning: - if self.debug: print(self.cursor.messages) + if self.debug: + print(self.cursor.messages) except self.connection.DataError: pass else: - self.fail("Over-long columns did not generate warnings/exception with execute()") + self.fail( + "Over-long columns did not generate warnings/exception with execute()" + ) self.connection.rollback() try: - data = [ [ generator(i,j) for j in range(len(columndefs)) ] - for i in range(self.rows) ] + data = [ + [generator(i, j) for j in range(len(columndefs))] + for i in range(self.rows) + ] self.cursor.executemany(insert_statement, data) except Warning: - if self.debug: print(self.cursor.messages) + if self.debug: + print(self.cursor.messages) except self.connection.DataError: pass else: - self.fail("Over-long columns did not generate warnings/exception with executemany()") + self.fail( + "Over-long columns did not generate warnings/exception with executemany()" + ) self.connection.rollback() - self.cursor.execute('drop table %s' % (self.table)) + self.cursor.execute("drop table %s" % (self.table)) def test_CHAR(self): # Character data - def generator(row,col): - return ('%i' % ((row+col) % 10)) * 255 - self.check_data_integrity( - ('col1 char(255)','col2 char(255)'), - generator) + def generator(row, col): + return ("%i" % ((row + col) % 10)) * 255 + + self.check_data_integrity(("col1 char(255)", "col2 char(255)"), generator) def test_INT(self): # Number data - def generator(row,col): - return row*row - self.check_data_integrity( - ('col1 INT',), - generator) + def generator(row, col): + return row * row + + self.check_data_integrity(("col1 INT",), generator) def test_DECIMAL(self): # DECIMAL - def generator(row,col): + def generator(row, col): from decimal import Decimal + return Decimal("%d.%02d" % (row, col)) - self.check_data_integrity( - ('col1 DECIMAL(5,2)',), - generator) + + self.check_data_integrity(("col1 DECIMAL(5,2)",), generator) def test_DATE(self): ticks = time() - def generator(row,col): - return self.db_module.DateFromTicks(ticks+row*86400-col*1313) - self.check_data_integrity( - ('col1 DATE',), - generator) + + def generator(row, col): + return self.db_module.DateFromTicks(ticks + row * 86400 - col * 1313) + + self.check_data_integrity(("col1 DATE",), generator) def test_TIME(self): ticks = time() - def generator(row,col): - return self.db_module.TimeFromTicks(ticks+row*86400-col*1313) - self.check_data_integrity( - ('col1 TIME',), - generator) + + def generator(row, col): + return self.db_module.TimeFromTicks(ticks + row * 86400 - col * 1313) + + self.check_data_integrity(("col1 TIME",), generator) def test_DATETIME(self): ticks = time() - def generator(row,col): - return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313) - self.check_data_integrity( - ('col1 DATETIME',), - generator) + + def generator(row, col): + return self.db_module.TimestampFromTicks(ticks + row * 86400 - col * 1313) + + self.check_data_integrity(("col1 DATETIME",), generator) def test_TIMESTAMP(self): ticks = time() - def generator(row,col): - return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313) - self.check_data_integrity( - ('col1 TIMESTAMP',), - generator) + + def generator(row, col): + return self.db_module.TimestampFromTicks(ticks + row * 86400 - col * 1313) + + self.check_data_integrity(("col1 TIMESTAMP",), generator) def test_fractional_TIMESTAMP(self): ticks = time() - def generator(row,col): - return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313+row*0.7*col/3.0) - self.check_data_integrity( - ('col1 TIMESTAMP',), - generator) + + def generator(row, col): + return self.db_module.TimestampFromTicks( + ticks + row * 86400 - col * 1313 + row * 0.7 * col / 3.0 + ) + + self.check_data_integrity(("col1 TIMESTAMP",), generator) def test_LONG(self): - def generator(row,col): + def generator(row, col): if col == 0: return row else: - return self.BLOBUText # 'BLOB Text ' * 1024 - self.check_data_integrity( - ('col1 INT', 'col2 LONG'), - generator) + return self.BLOBUText # 'BLOB Text ' * 1024 + + self.check_data_integrity(("col1 INT", "col2 LONG"), generator) def test_TEXT(self): - def generator(row,col): + def generator(row, col): if col == 0: return row else: - return self.BLOBUText[:5192] # 'BLOB Text ' * 1024 - self.check_data_integrity( - ('col1 INT', 'col2 TEXT'), - generator) + return self.BLOBUText[:5192] # 'BLOB Text ' * 1024 + + self.check_data_integrity(("col1 INT", "col2 TEXT"), generator) def test_LONG_BYTE(self): - def generator(row,col): + def generator(row, col): if col == 0: return row else: - return self.BLOBBinary # 'BLOB\000Binary ' * 1024 - self.check_data_integrity( - ('col1 INT','col2 LONG BYTE'), - generator) + return self.BLOBBinary # 'BLOB\000Binary ' * 1024 + + self.check_data_integrity(("col1 INT", "col2 LONG BYTE"), generator) def test_BLOB(self): - def generator(row,col): + def generator(row, col): if col == 0: return row else: - return self.BLOBBinary # 'BLOB\000Binary ' * 1024 - self.check_data_integrity( - ('col1 INT','col2 BLOB'), - generator) + return self.BLOBBinary # 'BLOB\000Binary ' * 1024 + + self.check_data_integrity(("col1 INT", "col2 BLOB"), generator) diff --git a/pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py b/pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py index 3cbf2263d..fff14b86f 100644 --- a/pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py +++ b/pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py @@ -1,4 +1,4 @@ -''' Python DB API 2.0 driver compliance unit test suite. +""" Python DB API 2.0 driver compliance unit test suite. This software is Public Domain and may be used without restrictions. @@ -8,18 +8,14 @@ this is turning out to be a thoroughly unwholesome unit test." -- Ian Bicking -''' +""" -__rcs_id__ = '$Id$' -__version__ = '$Revision$'[11:-2] -__author__ = 'Stuart Bishop ' - -try: - import unittest2 as unittest -except ImportError: - import unittest +__rcs_id__ = "$Id$" +__version__ = "$Revision$"[11:-2] +__author__ = "Stuart Bishop " import time +import unittest # $Log$ # Revision 1.1.2.1 2006/02/25 03:44:32 adustman @@ -55,9 +51,9 @@ # - Now a subclass of TestCase, to avoid requiring the driver stub # to use multiple inheritance # - Reversed the polarity of buggy test in test_description -# - Test exception heirarchy correctly +# - Test exception hierarchy correctly # - self.populate is now self._populate(), so if a driver stub -# overrides self.ddl1 this change propogates +# overrides self.ddl1 this change propagates # - VARCHAR columns now have a width, which will hopefully make the # DDL even more portible (this will be reversed if it causes more problems) # - cursor.rowcount being checked after various execute and fetchXXX methods @@ -67,65 +63,66 @@ # - Fix bugs in test_setoutputsize_basic and test_setinputsizes # + class DatabaseAPI20Test(unittest.TestCase): - ''' Test a database self.driver for DB API 2.0 compatibility. - This implementation tests Gadfly, but the TestCase - is structured so that other self.drivers can subclass this - test case to ensure compiliance with the DB-API. It is - expected that this TestCase may be expanded in the future - if ambiguities or edge conditions are discovered. + """Test a database self.driver for DB API 2.0 compatibility. + This implementation tests Gadfly, but the TestCase + is structured so that other self.drivers can subclass this + test case to ensure compiliance with the DB-API. It is + expected that this TestCase may be expanded in the future + if ambiguities or edge conditions are discovered. - The 'Optional Extensions' are not yet being tested. + The 'Optional Extensions' are not yet being tested. - self.drivers should subclass this test, overriding setUp, tearDown, - self.driver, connect_args and connect_kw_args. Class specification - should be as follows: + self.drivers should subclass this test, overriding setUp, tearDown, + self.driver, connect_args and connect_kw_args. Class specification + should be as follows: - import dbapi20 - class mytest(dbapi20.DatabaseAPI20Test): - [...] + import dbapi20 + class mytest(dbapi20.DatabaseAPI20Test): + [...] - Don't 'import DatabaseAPI20Test from dbapi20', or you will - confuse the unit tester - just 'import dbapi20'. - ''' + Don't 'import DatabaseAPI20Test from dbapi20', or you will + confuse the unit tester - just 'import dbapi20'. + """ # The self.driver module. This should be the module where the 'connect' # method is to be found driver = None - connect_args = () # List of arguments to pass to connect - connect_kw_args = {} # Keyword arguments for connect - table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables + connect_args = () # List of arguments to pass to connect + connect_kw_args = {} # Keyword arguments for connect + table_prefix = "dbapi20test_" # If you need to specify a prefix for tables - ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix - ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix - xddl1 = 'drop table %sbooze' % table_prefix - xddl2 = 'drop table %sbarflys' % table_prefix + ddl1 = "create table %sbooze (name varchar(20))" % table_prefix + ddl2 = "create table %sbarflys (name varchar(20))" % table_prefix + xddl1 = "drop table %sbooze" % table_prefix + xddl2 = "drop table %sbarflys" % table_prefix - lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase + lowerfunc = "lower" # Name of stored procedure to convert string->lowercase # Some drivers may need to override these helpers, for example adding # a 'commit' after the execute. - def executeDDL1(self,cursor): + def executeDDL1(self, cursor): cursor.execute(self.ddl1) - def executeDDL2(self,cursor): + def executeDDL2(self, cursor): cursor.execute(self.ddl2) def setUp(self): - ''' self.drivers should override this method to perform required setup - if any is necessary, such as creating the database. - ''' + """self.drivers should override this method to perform required setup + if any is necessary, such as creating the database. + """ pass def tearDown(self): - ''' self.drivers should override this method to perform required cleanup - if any is necessary, such as deleting the test database. - The default drops the tables that may be created. - ''' + """self.drivers should override this method to perform required cleanup + if any is necessary, such as deleting the test database. + The default drops the tables that may be created. + """ con = self._connect() try: cur = con.cursor() - for ddl in (self.xddl1,self.xddl2): + for ddl in (self.xddl1, self.xddl2): try: cur.execute(ddl) con.commit() @@ -138,9 +135,7 @@ def tearDown(self): def _connect(self): try: - return self.driver.connect( - *self.connect_args,**self.connect_kw_args - ) + return self.driver.connect(*self.connect_args, **self.connect_kw_args) except AttributeError: self.fail("No connect method found in self.driver module") @@ -153,7 +148,7 @@ def test_apilevel(self): # Must exist apilevel = self.driver.apilevel # Must equal 2.0 - self.assertEqual(apilevel,'2.0') + self.assertEqual(apilevel, "2.0") except AttributeError: self.fail("Driver doesn't define apilevel") @@ -162,7 +157,7 @@ def test_threadsafety(self): # Must exist threadsafety = self.driver.threadsafety # Must be a valid value - self.assertTrue(threadsafety in (0,1,2,3)) + self.assertTrue(threadsafety in (0, 1, 2, 3)) except AttributeError: self.fail("Driver doesn't define threadsafety") @@ -171,38 +166,24 @@ def test_paramstyle(self): # Must exist paramstyle = self.driver.paramstyle # Must be a valid value - self.assertTrue(paramstyle in ( - 'qmark','numeric','named','format','pyformat' - )) + self.assertTrue( + paramstyle in ("qmark", "numeric", "named", "format", "pyformat") + ) except AttributeError: self.fail("Driver doesn't define paramstyle") def test_Exceptions(self): # Make sure required exceptions exist, and are in the - # defined heirarchy. - self.assertTrue(issubclass(self.driver.Warning,Exception)) - self.assertTrue(issubclass(self.driver.Error,Exception)) - self.assertTrue( - issubclass(self.driver.InterfaceError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.DatabaseError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.OperationalError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.IntegrityError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.InternalError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.ProgrammingError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.NotSupportedError,self.driver.Error) - ) + # defined hierarchy. + self.assertTrue(issubclass(self.driver.Warning, Exception)) + self.assertTrue(issubclass(self.driver.Error, Exception)) + self.assertTrue(issubclass(self.driver.InterfaceError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.DatabaseError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.OperationalError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.IntegrityError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.InternalError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.ProgrammingError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.NotSupportedError, self.driver.Error)) def test_ExceptionsAsConnectionAttributes(self): # OPTIONAL EXTENSION @@ -223,7 +204,6 @@ def test_ExceptionsAsConnectionAttributes(self): self.assertTrue(con.ProgrammingError is drv.ProgrammingError) self.assertTrue(con.NotSupportedError is drv.NotSupportedError) - def test_commit(self): con = self._connect() try: @@ -236,7 +216,7 @@ def test_rollback(self): con = self._connect() # If rollback is defined, it should either work or throw # the documented exception - if hasattr(con,'rollback'): + if hasattr(con, "rollback"): try: con.rollback() except self.driver.NotSupportedError: @@ -245,7 +225,7 @@ def test_rollback(self): def test_cursor(self): con = self._connect() try: - cur = con.cursor() + con.cursor() finally: con.close() @@ -257,14 +237,14 @@ def test_cursor_isolation(self): cur1 = con.cursor() cur2 = con.cursor() self.executeDDL1(cur1) - cur1.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) + cur1.execute( + "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) + ) cur2.execute("select name from %sbooze" % self.table_prefix) booze = cur2.fetchall() - self.assertEqual(len(booze),1) - self.assertEqual(len(booze[0]),1) - self.assertEqual(booze[0][0],'Victoria Bitter') + self.assertEqual(len(booze), 1) + self.assertEqual(len(booze[0]), 1) + self.assertEqual(booze[0][0], "Victoria Bitter") finally: con.close() @@ -273,31 +253,41 @@ def test_description(self): try: cur = con.cursor() self.executeDDL1(cur) - self.assertEqual(cur.description,None, - 'cursor.description should be none after executing a ' - 'statement that can return no rows (such as DDL)' - ) - cur.execute('select name from %sbooze' % self.table_prefix) - self.assertEqual(len(cur.description),1, - 'cursor.description describes too many columns' - ) - self.assertEqual(len(cur.description[0]),7, - 'cursor.description[x] tuples must have 7 elements' - ) - self.assertEqual(cur.description[0][0].lower(),'name', - 'cursor.description[x][0] must return column name' - ) - self.assertEqual(cur.description[0][1],self.driver.STRING, - 'cursor.description[x][1] must return column type. Got %r' - % cur.description[0][1] - ) + self.assertEqual( + cur.description, + None, + "cursor.description should be none after executing a " + "statement that can return no rows (such as DDL)", + ) + cur.execute("select name from %sbooze" % self.table_prefix) + self.assertEqual( + len(cur.description), 1, "cursor.description describes too many columns" + ) + self.assertEqual( + len(cur.description[0]), + 7, + "cursor.description[x] tuples must have 7 elements", + ) + self.assertEqual( + cur.description[0][0].lower(), + "name", + "cursor.description[x][0] must return column name", + ) + self.assertEqual( + cur.description[0][1], + self.driver.STRING, + "cursor.description[x][1] must return column type. Got %r" + % cur.description[0][1], + ) # Make sure self.description gets reset self.executeDDL2(cur) - self.assertEqual(cur.description,None, - 'cursor.description not being set to None when executing ' - 'no-result statements (eg. DDL)' - ) + self.assertEqual( + cur.description, + None, + "cursor.description not being set to None when executing " + "no-result statements (eg. DDL)", + ) finally: con.close() @@ -306,47 +296,49 @@ def test_rowcount(self): try: cur = con.cursor() self.executeDDL1(cur) - self.assertEqual(cur.rowcount,-1, - 'cursor.rowcount should be -1 after executing no-result ' - 'statements' - ) - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) - self.assertTrue(cur.rowcount in (-1,1), - 'cursor.rowcount should == number or rows inserted, or ' - 'set to -1 after executing an insert statement' - ) + self.assertEqual( + cur.rowcount, + -1, + "cursor.rowcount should be -1 after executing no-result statements", + ) + cur.execute( + "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) + ) + self.assertTrue( + cur.rowcount in (-1, 1), + "cursor.rowcount should == number or rows inserted, or " + "set to -1 after executing an insert statement", + ) cur.execute("select name from %sbooze" % self.table_prefix) - self.assertTrue(cur.rowcount in (-1,1), - 'cursor.rowcount should == number of rows returned, or ' - 'set to -1 after executing a select statement' - ) + self.assertTrue( + cur.rowcount in (-1, 1), + "cursor.rowcount should == number of rows returned, or " + "set to -1 after executing a select statement", + ) self.executeDDL2(cur) - self.assertEqual(cur.rowcount,-1, - 'cursor.rowcount not being reset to -1 after executing ' - 'no-result statements' - ) + self.assertEqual( + cur.rowcount, + -1, + "cursor.rowcount not being reset to -1 after executing " + "no-result statements", + ) finally: con.close() - lower_func = 'lower' + lower_func = "lower" + def test_callproc(self): con = self._connect() try: cur = con.cursor() - if self.lower_func and hasattr(cur,'callproc'): - r = cur.callproc(self.lower_func,('FOO',)) - self.assertEqual(len(r),1) - self.assertEqual(r[0],'FOO') + if self.lower_func and hasattr(cur, "callproc"): + r = cur.callproc(self.lower_func, ("FOO",)) + self.assertEqual(len(r), 1) + self.assertEqual(r[0], "FOO") r = cur.fetchall() - self.assertEqual(len(r),1,'callproc produced no result set') - self.assertEqual(len(r[0]),1, - 'callproc produced invalid result set' - ) - self.assertEqual(r[0][0],'foo', - 'callproc produced invalid results' - ) + self.assertEqual(len(r), 1, "callproc produced no result set") + self.assertEqual(len(r[0]), 1, "callproc produced invalid result set") + self.assertEqual(r[0][0], "foo", "callproc produced invalid results") finally: con.close() @@ -359,14 +351,14 @@ def test_close(self): # cursor.execute should raise an Error if called after connection # closed - self.assertRaises(self.driver.Error,self.executeDDL1,cur) + self.assertRaises(self.driver.Error, self.executeDDL1, cur) # connection.commit should raise an Error if called after connection' # closed.' - self.assertRaises(self.driver.Error,con.commit) + self.assertRaises(self.driver.Error, con.commit) # connection.close should raise an Error if called more than once - self.assertRaises(self.driver.Error,con.close) + self.assertRaises(self.driver.Error, con.close) def test_execute(self): con = self._connect() @@ -376,105 +368,99 @@ def test_execute(self): finally: con.close() - def _paraminsert(self,cur): + def _paraminsert(self, cur): self.executeDDL1(cur) - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) - self.assertTrue(cur.rowcount in (-1,1)) + cur.execute( + "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) + ) + self.assertTrue(cur.rowcount in (-1, 1)) - if self.driver.paramstyle == 'qmark': + if self.driver.paramstyle == "qmark": cur.execute( - 'insert into %sbooze values (?)' % self.table_prefix, - ("Cooper's",) - ) - elif self.driver.paramstyle == 'numeric': + "insert into %sbooze values (?)" % self.table_prefix, ("Cooper's",) + ) + elif self.driver.paramstyle == "numeric": cur.execute( - 'insert into %sbooze values (:1)' % self.table_prefix, - ("Cooper's",) - ) - elif self.driver.paramstyle == 'named': + "insert into %sbooze values (:1)" % self.table_prefix, ("Cooper's",) + ) + elif self.driver.paramstyle == "named": cur.execute( - 'insert into %sbooze values (:beer)' % self.table_prefix, - {'beer':"Cooper's"} - ) - elif self.driver.paramstyle == 'format': + "insert into %sbooze values (:beer)" % self.table_prefix, + {"beer": "Cooper's"}, + ) + elif self.driver.paramstyle == "format": cur.execute( - 'insert into %sbooze values (%%s)' % self.table_prefix, - ("Cooper's",) - ) - elif self.driver.paramstyle == 'pyformat': + "insert into %sbooze values (%%s)" % self.table_prefix, ("Cooper's",) + ) + elif self.driver.paramstyle == "pyformat": cur.execute( - 'insert into %sbooze values (%%(beer)s)' % self.table_prefix, - {'beer':"Cooper's"} - ) + "insert into %sbooze values (%%(beer)s)" % self.table_prefix, + {"beer": "Cooper's"}, + ) else: - self.fail('Invalid paramstyle') - self.assertTrue(cur.rowcount in (-1,1)) + self.fail("Invalid paramstyle") + self.assertTrue(cur.rowcount in (-1, 1)) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) res = cur.fetchall() - self.assertEqual(len(res),2,'cursor.fetchall returned too few rows') - beers = [res[0][0],res[1][0]] + self.assertEqual(len(res), 2, "cursor.fetchall returned too few rows") + beers = [res[0][0], res[1][0]] beers.sort() - self.assertEqual(beers[0],"Cooper's", - 'cursor.fetchall retrieved incorrect data, or data inserted ' - 'incorrectly' - ) - self.assertEqual(beers[1],"Victoria Bitter", - 'cursor.fetchall retrieved incorrect data, or data inserted ' - 'incorrectly' - ) + self.assertEqual( + beers[0], + "Cooper's", + "cursor.fetchall retrieved incorrect data, or data inserted incorrectly", + ) + self.assertEqual( + beers[1], + "Victoria Bitter", + "cursor.fetchall retrieved incorrect data, or data inserted incorrectly", + ) def test_executemany(self): con = self._connect() try: cur = con.cursor() self.executeDDL1(cur) - largs = [ ("Cooper's",) , ("Boag's",) ] - margs = [ {'beer': "Cooper's"}, {'beer': "Boag's"} ] - if self.driver.paramstyle == 'qmark': + largs = [("Cooper's",), ("Boag's",)] + margs = [{"beer": "Cooper's"}, {"beer": "Boag's"}] + if self.driver.paramstyle == "qmark": cur.executemany( - 'insert into %sbooze values (?)' % self.table_prefix, - largs - ) - elif self.driver.paramstyle == 'numeric': + "insert into %sbooze values (?)" % self.table_prefix, largs + ) + elif self.driver.paramstyle == "numeric": cur.executemany( - 'insert into %sbooze values (:1)' % self.table_prefix, - largs - ) - elif self.driver.paramstyle == 'named': + "insert into %sbooze values (:1)" % self.table_prefix, largs + ) + elif self.driver.paramstyle == "named": cur.executemany( - 'insert into %sbooze values (:beer)' % self.table_prefix, - margs - ) - elif self.driver.paramstyle == 'format': + "insert into %sbooze values (:beer)" % self.table_prefix, margs + ) + elif self.driver.paramstyle == "format": cur.executemany( - 'insert into %sbooze values (%%s)' % self.table_prefix, - largs - ) - elif self.driver.paramstyle == 'pyformat': + "insert into %sbooze values (%%s)" % self.table_prefix, largs + ) + elif self.driver.paramstyle == "pyformat": cur.executemany( - 'insert into %sbooze values (%%(beer)s)' % ( - self.table_prefix - ), - margs - ) - else: - self.fail('Unknown paramstyle') - self.assertTrue(cur.rowcount in (-1,2), - 'insert using cursor.executemany set cursor.rowcount to ' - 'incorrect value %r' % cur.rowcount + "insert into %sbooze values (%%(beer)s)" % (self.table_prefix), + margs, ) - cur.execute('select name from %sbooze' % self.table_prefix) + else: + self.fail("Unknown paramstyle") + self.assertTrue( + cur.rowcount in (-1, 2), + "insert using cursor.executemany set cursor.rowcount to " + "incorrect value %r" % cur.rowcount, + ) + cur.execute("select name from %sbooze" % self.table_prefix) res = cur.fetchall() - self.assertEqual(len(res),2, - 'cursor.fetchall retrieved incorrect number of rows' - ) - beers = [res[0][0],res[1][0]] + self.assertEqual( + len(res), 2, "cursor.fetchall retrieved incorrect number of rows" + ) + beers = [res[0][0], res[1][0]] beers.sort() - self.assertEqual(beers[0],"Boag's",'incorrect data retrieved') - self.assertEqual(beers[1],"Cooper's",'incorrect data retrieved') + self.assertEqual(beers[0], "Boag's", "incorrect data retrieved") + self.assertEqual(beers[1], "Cooper's", "incorrect data retrieved") finally: con.close() @@ -485,59 +471,62 @@ def test_fetchone(self): # cursor.fetchone should raise an Error if called before # executing a select-type query - self.assertRaises(self.driver.Error,cur.fetchone) + self.assertRaises(self.driver.Error, cur.fetchone) # cursor.fetchone should raise an Error if called after - # executing a query that cannnot return rows + # executing a query that cannot return rows self.executeDDL1(cur) - self.assertRaises(self.driver.Error,cur.fetchone) + self.assertRaises(self.driver.Error, cur.fetchone) - cur.execute('select name from %sbooze' % self.table_prefix) - self.assertEqual(cur.fetchone(),None, - 'cursor.fetchone should return None if a query retrieves ' - 'no rows' - ) - self.assertTrue(cur.rowcount in (-1,0)) + cur.execute("select name from %sbooze" % self.table_prefix) + self.assertEqual( + cur.fetchone(), + None, + "cursor.fetchone should return None if a query retrieves no rows", + ) + self.assertTrue(cur.rowcount in (-1, 0)) # cursor.fetchone should raise an Error if called after - # executing a query that cannnot return rows - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) - self.assertRaises(self.driver.Error,cur.fetchone) + # executing a query that cannot return rows + cur.execute( + "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) + ) + self.assertRaises(self.driver.Error, cur.fetchone) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) r = cur.fetchone() - self.assertEqual(len(r),1, - 'cursor.fetchone should have retrieved a single row' - ) - self.assertEqual(r[0],'Victoria Bitter', - 'cursor.fetchone retrieved incorrect data' - ) - self.assertEqual(cur.fetchone(),None, - 'cursor.fetchone should return None if no more rows available' - ) - self.assertTrue(cur.rowcount in (-1,1)) + self.assertEqual( + len(r), 1, "cursor.fetchone should have retrieved a single row" + ) + self.assertEqual( + r[0], "Victoria Bitter", "cursor.fetchone retrieved incorrect data" + ) + self.assertEqual( + cur.fetchone(), + None, + "cursor.fetchone should return None if no more rows available", + ) + self.assertTrue(cur.rowcount in (-1, 1)) finally: con.close() samples = [ - 'Carlton Cold', - 'Carlton Draft', - 'Mountain Goat', - 'Redback', - 'Victoria Bitter', - 'XXXX' - ] + "Carlton Cold", + "Carlton Draft", + "Mountain Goat", + "Redback", + "Victoria Bitter", + "XXXX", + ] def _populate(self): - ''' Return a list of sql commands to setup the DB for the fetch - tests. - ''' + """Return a list of sql commands to setup the DB for the fetch + tests. + """ populate = [ - "insert into %sbooze values ('%s')" % (self.table_prefix,s) - for s in self.samples - ] + "insert into %sbooze values ('%s')" % (self.table_prefix, s) + for s in self.samples + ] return populate def test_fetchmany(self): @@ -546,78 +535,88 @@ def test_fetchmany(self): cur = con.cursor() # cursor.fetchmany should raise an Error if called without - #issuing a query - self.assertRaises(self.driver.Error,cur.fetchmany,4) + # issuing a query + self.assertRaises(self.driver.Error, cur.fetchmany, 4) self.executeDDL1(cur) for sql in self._populate(): cur.execute(sql) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) r = cur.fetchmany() - self.assertEqual(len(r),1, - 'cursor.fetchmany retrieved incorrect number of rows, ' - 'default of arraysize is one.' - ) - cur.arraysize=10 - r = cur.fetchmany(3) # Should get 3 rows - self.assertEqual(len(r),3, - 'cursor.fetchmany retrieved incorrect number of rows' - ) - r = cur.fetchmany(4) # Should get 2 more - self.assertEqual(len(r),2, - 'cursor.fetchmany retrieved incorrect number of rows' - ) - r = cur.fetchmany(4) # Should be an empty sequence - self.assertEqual(len(r),0, - 'cursor.fetchmany should return an empty sequence after ' - 'results are exhausted' + self.assertEqual( + len(r), + 1, + "cursor.fetchmany retrieved incorrect number of rows, " + "default of arraysize is one.", + ) + cur.arraysize = 10 + r = cur.fetchmany(3) # Should get 3 rows + self.assertEqual( + len(r), 3, "cursor.fetchmany retrieved incorrect number of rows" + ) + r = cur.fetchmany(4) # Should get 2 more + self.assertEqual( + len(r), 2, "cursor.fetchmany retrieved incorrect number of rows" + ) + r = cur.fetchmany(4) # Should be an empty sequence + self.assertEqual( + len(r), + 0, + "cursor.fetchmany should return an empty sequence after " + "results are exhausted", ) - self.assertTrue(cur.rowcount in (-1,6)) + self.assertTrue(cur.rowcount in (-1, 6)) # Same as above, using cursor.arraysize - cur.arraysize=4 - cur.execute('select name from %sbooze' % self.table_prefix) - r = cur.fetchmany() # Should get 4 rows - self.assertEqual(len(r),4, - 'cursor.arraysize not being honoured by fetchmany' - ) - r = cur.fetchmany() # Should get 2 more - self.assertEqual(len(r),2) - r = cur.fetchmany() # Should be an empty sequence - self.assertEqual(len(r),0) - self.assertTrue(cur.rowcount in (-1,6)) - - cur.arraysize=6 - cur.execute('select name from %sbooze' % self.table_prefix) - rows = cur.fetchmany() # Should get all rows - self.assertTrue(cur.rowcount in (-1,6)) - self.assertEqual(len(rows),6) - self.assertEqual(len(rows),6) + cur.arraysize = 4 + cur.execute("select name from %sbooze" % self.table_prefix) + r = cur.fetchmany() # Should get 4 rows + self.assertEqual( + len(r), 4, "cursor.arraysize not being honoured by fetchmany" + ) + r = cur.fetchmany() # Should get 2 more + self.assertEqual(len(r), 2) + r = cur.fetchmany() # Should be an empty sequence + self.assertEqual(len(r), 0) + self.assertTrue(cur.rowcount in (-1, 6)) + + cur.arraysize = 6 + cur.execute("select name from %sbooze" % self.table_prefix) + rows = cur.fetchmany() # Should get all rows + self.assertTrue(cur.rowcount in (-1, 6)) + self.assertEqual(len(rows), 6) + self.assertEqual(len(rows), 6) rows = [r[0] for r in rows] rows.sort() # Make sure we get the right data back out - for i in range(0,6): - self.assertEqual(rows[i],self.samples[i], - 'incorrect data retrieved by cursor.fetchmany' - ) - - rows = cur.fetchmany() # Should return an empty list - self.assertEqual(len(rows),0, - 'cursor.fetchmany should return an empty sequence if ' - 'called after the whole result set has been fetched' + for i in range(0, 6): + self.assertEqual( + rows[i], + self.samples[i], + "incorrect data retrieved by cursor.fetchmany", ) - self.assertTrue(cur.rowcount in (-1,6)) + + rows = cur.fetchmany() # Should return an empty list + self.assertEqual( + len(rows), + 0, + "cursor.fetchmany should return an empty sequence if " + "called after the whole result set has been fetched", + ) + self.assertTrue(cur.rowcount in (-1, 6)) self.executeDDL2(cur) - cur.execute('select name from %sbarflys' % self.table_prefix) - r = cur.fetchmany() # Should get empty sequence - self.assertEqual(len(r),0, - 'cursor.fetchmany should return an empty sequence if ' - 'query retrieved no rows' - ) - self.assertTrue(cur.rowcount in (-1,0)) + cur.execute("select name from %sbarflys" % self.table_prefix) + r = cur.fetchmany() # Should get empty sequence + self.assertEqual( + len(r), + 0, + "cursor.fetchmany should return an empty sequence if " + "query retrieved no rows", + ) + self.assertTrue(cur.rowcount in (-1, 0)) finally: con.close() @@ -637,36 +636,41 @@ def test_fetchall(self): # cursor.fetchall should raise an Error if called # after executing a a statement that cannot return rows - self.assertRaises(self.driver.Error,cur.fetchall) + self.assertRaises(self.driver.Error, cur.fetchall) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) rows = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,len(self.samples))) - self.assertEqual(len(rows),len(self.samples), - 'cursor.fetchall did not retrieve all rows' - ) + self.assertTrue(cur.rowcount in (-1, len(self.samples))) + self.assertEqual( + len(rows), + len(self.samples), + "cursor.fetchall did not retrieve all rows", + ) rows = [r[0] for r in rows] rows.sort() - for i in range(0,len(self.samples)): - self.assertEqual(rows[i],self.samples[i], - 'cursor.fetchall retrieved incorrect rows' + for i in range(0, len(self.samples)): + self.assertEqual( + rows[i], self.samples[i], "cursor.fetchall retrieved incorrect rows" ) rows = cur.fetchall() self.assertEqual( - len(rows),0, - 'cursor.fetchall should return an empty list if called ' - 'after the whole result set has been fetched' - ) - self.assertTrue(cur.rowcount in (-1,len(self.samples))) + len(rows), + 0, + "cursor.fetchall should return an empty list if called " + "after the whole result set has been fetched", + ) + self.assertTrue(cur.rowcount in (-1, len(self.samples))) self.executeDDL2(cur) - cur.execute('select name from %sbarflys' % self.table_prefix) + cur.execute("select name from %sbarflys" % self.table_prefix) rows = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,0)) - self.assertEqual(len(rows),0, - 'cursor.fetchall should return an empty list if ' - 'a select query returns no rows' - ) + self.assertTrue(cur.rowcount in (-1, 0)) + self.assertEqual( + len(rows), + 0, + "cursor.fetchall should return an empty list if " + "a select query returns no rows", + ) finally: con.close() @@ -679,74 +683,74 @@ def test_mixedfetch(self): for sql in self._populate(): cur.execute(sql) - cur.execute('select name from %sbooze' % self.table_prefix) - rows1 = cur.fetchone() + cur.execute("select name from %sbooze" % self.table_prefix) + rows1 = cur.fetchone() rows23 = cur.fetchmany(2) - rows4 = cur.fetchone() + rows4 = cur.fetchone() rows56 = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,6)) - self.assertEqual(len(rows23),2, - 'fetchmany returned incorrect number of rows' - ) - self.assertEqual(len(rows56),2, - 'fetchall returned incorrect number of rows' - ) + self.assertTrue(cur.rowcount in (-1, 6)) + self.assertEqual( + len(rows23), 2, "fetchmany returned incorrect number of rows" + ) + self.assertEqual( + len(rows56), 2, "fetchall returned incorrect number of rows" + ) rows = [rows1[0]] - rows.extend([rows23[0][0],rows23[1][0]]) + rows.extend([rows23[0][0], rows23[1][0]]) rows.append(rows4[0]) - rows.extend([rows56[0][0],rows56[1][0]]) + rows.extend([rows56[0][0], rows56[1][0]]) rows.sort() - for i in range(0,len(self.samples)): - self.assertEqual(rows[i],self.samples[i], - 'incorrect data retrieved or inserted' - ) + for i in range(0, len(self.samples)): + self.assertEqual( + rows[i], self.samples[i], "incorrect data retrieved or inserted" + ) finally: con.close() - def help_nextset_setUp(self,cur): - ''' Should create a procedure called deleteme - that returns two result sets, first the - number of rows in booze then "name from booze" - ''' - raise NotImplementedError('Helper not implemented') - #sql=""" + def help_nextset_setUp(self, cur): + """Should create a procedure called deleteme + that returns two result sets, first the + number of rows in booze then "name from booze" + """ + raise NotImplementedError("Helper not implemented") + # sql=""" # create procedure deleteme as # begin # select count(*) from booze # select name from booze # end - #""" - #cur.execute(sql) + # """ + # cur.execute(sql) - def help_nextset_tearDown(self,cur): - 'If cleaning up is needed after nextSetTest' - raise NotImplementedError('Helper not implemented') - #cur.execute("drop procedure deleteme") + def help_nextset_tearDown(self, cur): + "If cleaning up is needed after nextSetTest" + raise NotImplementedError("Helper not implemented") + # cur.execute("drop procedure deleteme") def test_nextset(self): con = self._connect() try: cur = con.cursor() - if not hasattr(cur,'nextset'): + if not hasattr(cur, "nextset"): return try: self.executeDDL1(cur) - sql=self._populate() + sql = self._populate() for sql in self._populate(): cur.execute(sql) self.help_nextset_setUp(cur) - cur.callproc('deleteme') - numberofrows=cur.fetchone() - assert numberofrows[0]== len(self.samples) + cur.callproc("deleteme") + numberofrows = cur.fetchone() + assert numberofrows[0] == len(self.samples) assert cur.nextset() - names=cur.fetchall() + names = cur.fetchall() assert len(names) == len(self.samples) - s=cur.nextset() - assert s == None,'No more return sets, should return None' + s = cur.nextset() + assert s == None, "No more return sets, should return None" finally: self.help_nextset_tearDown(cur) @@ -754,16 +758,16 @@ def test_nextset(self): con.close() def test_nextset(self): - raise NotImplementedError('Drivers need to override this test') + raise NotImplementedError("Drivers need to override this test") def test_arraysize(self): # Not much here - rest of the tests for this are in test_fetchmany con = self._connect() try: cur = con.cursor() - self.assertTrue(hasattr(cur,'arraysize'), - 'cursor.arraysize must be defined' - ) + self.assertTrue( + hasattr(cur, "arraysize"), "cursor.arraysize must be defined" + ) finally: con.close() @@ -771,8 +775,8 @@ def test_setinputsizes(self): con = self._connect() try: cur = con.cursor() - cur.setinputsizes( (25,) ) - self._paraminsert(cur) # Make sure cursor still works + cur.setinputsizes((25,)) + self._paraminsert(cur) # Make sure cursor still works finally: con.close() @@ -782,74 +786,68 @@ def test_setoutputsize_basic(self): try: cur = con.cursor() cur.setoutputsize(1000) - cur.setoutputsize(2000,0) - self._paraminsert(cur) # Make sure the cursor still works + cur.setoutputsize(2000, 0) + self._paraminsert(cur) # Make sure the cursor still works finally: con.close() def test_setoutputsize(self): - # Real test for setoutputsize is driver dependant - raise NotImplementedError('Driver need to override this test') + # Real test for setoutputsize is driver dependent + raise NotImplementedError("Driver need to override this test") def test_None(self): con = self._connect() try: cur = con.cursor() self.executeDDL1(cur) - cur.execute('insert into %sbooze values (NULL)' % self.table_prefix) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("insert into %sbooze values (NULL)" % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) r = cur.fetchall() - self.assertEqual(len(r),1) - self.assertEqual(len(r[0]),1) - self.assertEqual(r[0][0],None,'NULL value not returned as None') + self.assertEqual(len(r), 1) + self.assertEqual(len(r[0]), 1) + self.assertEqual(r[0][0], None, "NULL value not returned as None") finally: con.close() def test_Date(self): - d1 = self.driver.Date(2002,12,25) - d2 = self.driver.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0))) + self.driver.Date(2002, 12, 25) + self.driver.DateFromTicks(time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0))) # Can we assume this? API doesn't specify, but it seems implied # self.assertEqual(str(d1),str(d2)) def test_Time(self): - t1 = self.driver.Time(13,45,30) - t2 = self.driver.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0))) + self.driver.Time(13, 45, 30) + self.driver.TimeFromTicks(time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0))) # Can we assume this? API doesn't specify, but it seems implied # self.assertEqual(str(t1),str(t2)) def test_Timestamp(self): - t1 = self.driver.Timestamp(2002,12,25,13,45,30) - t2 = self.driver.TimestampFromTicks( - time.mktime((2002,12,25,13,45,30,0,0,0)) - ) + self.driver.Timestamp(2002, 12, 25, 13, 45, 30) + self.driver.TimestampFromTicks(time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0))) # Can we assume this? API doesn't specify, but it seems implied # self.assertEqual(str(t1),str(t2)) def test_Binary(self): - b = self.driver.Binary(b'Something') - b = self.driver.Binary(b'') + self.driver.Binary(b"Something") + self.driver.Binary(b"") def test_STRING(self): - self.assertTrue(hasattr(self.driver,'STRING'), - 'module.STRING must be defined' - ) + self.assertTrue(hasattr(self.driver, "STRING"), "module.STRING must be defined") def test_BINARY(self): - self.assertTrue(hasattr(self.driver,'BINARY'), - 'module.BINARY must be defined.' - ) + self.assertTrue( + hasattr(self.driver, "BINARY"), "module.BINARY must be defined." + ) def test_NUMBER(self): - self.assertTrue(hasattr(self.driver,'NUMBER'), - 'module.NUMBER must be defined.' - ) + self.assertTrue( + hasattr(self.driver, "NUMBER"), "module.NUMBER must be defined." + ) def test_DATETIME(self): - self.assertTrue(hasattr(self.driver,'DATETIME'), - 'module.DATETIME must be defined.' - ) + self.assertTrue( + hasattr(self.driver, "DATETIME"), "module.DATETIME must be defined." + ) def test_ROWID(self): - self.assertTrue(hasattr(self.driver,'ROWID'), - 'module.ROWID must be defined.' - ) + self.assertTrue(hasattr(self.driver, "ROWID"), "module.ROWID must be defined.") diff --git a/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py b/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py index 0fc5e8316..6a2894a5a 100644 --- a/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py +++ b/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py @@ -1,22 +1,24 @@ from . import capabilities -try: - import unittest2 as unittest -except ImportError: - import unittest import pymysql from pymysql.tests import base import warnings -warnings.filterwarnings('error') +warnings.filterwarnings("error") -class test_MySQLdb(capabilities.DatabaseTest): +class test_MySQLdb(capabilities.DatabaseTest): db_module = pymysql connect_args = () connect_kwargs = base.PyMySQLTestCase.databases[0].copy() - connect_kwargs.update(dict(read_default_file='~/.my.cnf', - use_unicode=True, binary_prefix=True, - charset='utf8mb4', sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL")) + connect_kwargs.update( + dict( + read_default_file="~/.my.cnf", + use_unicode=True, + binary_prefix=True, + charset="utf8mb4", + sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL", + ) + ) leak_test = False @@ -25,64 +27,70 @@ def quote_identifier(self, ident): def test_TIME(self): from datetime import timedelta - def generator(row,col): - return timedelta(0, row*8000) - self.check_data_integrity( - ('col1 TIME',), - generator) + + def generator(row, col): + return timedelta(0, row * 8000) + + self.check_data_integrity(("col1 TIME",), generator) def test_TINYINT(self): # Number data - def generator(row,col): - v = (row*row) % 256 + def generator(row, col): + v = (row * row) % 256 if v > 127: - v = v-256 + v = v - 256 return v - self.check_data_integrity( - ('col1 TINYINT',), - generator) + + self.check_data_integrity(("col1 TINYINT",), generator) def test_stored_procedures(self): db = self.connection c = self.cursor try: - self.create_table(('pos INT', 'tree CHAR(20)')) - c.executemany("INSERT INTO %s (pos,tree) VALUES (%%s,%%s)" % self.table, - list(enumerate('ash birch cedar larch pine'.split()))) + self.create_table(("pos INT", "tree CHAR(20)")) + c.executemany( + "INSERT INTO %s (pos,tree) VALUES (%%s,%%s)" % self.table, + list(enumerate("ash birch cedar larch pine".split())), + ) db.commit() - c.execute(""" + c.execute( + """ CREATE PROCEDURE test_sp(IN t VARCHAR(255)) BEGIN SELECT pos FROM %s WHERE tree = t; END - """ % self.table) + """ + % self.table + ) db.commit() - c.callproc('test_sp', ('larch',)) + c.callproc("test_sp", ("larch",)) rows = c.fetchall() self.assertEqual(len(rows), 1) self.assertEqual(rows[0][0], 3) c.nextset() finally: c.execute("DROP PROCEDURE IF EXISTS test_sp") - c.execute('drop table %s' % (self.table)) + c.execute("drop table %s" % (self.table)) def test_small_CHAR(self): # Character data - def generator(row,col): - i = ((row+1)*(col+1)+62)%256 - if i == 62: return '' - if i == 63: return None + def generator(row, col): + i = ((row + 1) * (col + 1) + 62) % 256 + if i == 62: + return "" + if i == 63: + return None return chr(i) - self.check_data_integrity( - ('col1 char(1)','col2 char(1)'), - generator) + + self.check_data_integrity(("col1 char(1)", "col2 char(1)"), generator) def test_bug_2671682(self): from pymysql.constants import ER + try: - self.cursor.execute("describe some_non_existent_table"); + self.cursor.execute("describe some_non_existent_table") except self.connection.ProgrammingError as msg: self.assertEqual(msg.args[0], ER.NO_SUCH_TABLE) @@ -93,7 +101,7 @@ def test_literal_int(self): self.assertTrue("2" == self.connection.literal(2)) def test_literal_float(self): - self.assertTrue("3.1415" == self.connection.literal(3.1415)) + self.assertEqual("3.1415e0", self.connection.literal(3.1415)) def test_literal_string(self): self.assertTrue("'foo'" == self.connection.literal("foo")) diff --git a/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py b/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py index a26691626..5c34d40d1 100644 --- a/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py +++ b/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py @@ -2,23 +2,24 @@ import pymysql from pymysql.tests import base -try: - import unittest2 as unittest -except ImportError: - import unittest - class test_MySQLdb(dbapi20.DatabaseAPI20Test): driver = pymysql connect_args = () connect_kw_args = base.PyMySQLTestCase.databases[0].copy() - connect_kw_args.update(dict(read_default_file='~/.my.cnf', - charset='utf8', - sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL")) + connect_kw_args.update( + dict( + read_default_file="~/.my.cnf", + charset="utf8", + sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL", + ) + ) + + def test_setoutputsize(self): + pass - def test_setoutputsize(self): pass - def test_setoutputsize_basic(self): pass - def test_nextset(self): pass + def test_setoutputsize_basic(self): + pass """The tests on fetchone and fetchall and rowcount bogusly test for an exception if the statement cannot return a @@ -40,36 +41,41 @@ def test_fetchall(self): # cursor.fetchall should raise an Error if called # after executing a a statement that cannot return rows -## self.assertRaises(self.driver.Error,cur.fetchall) + ## self.assertRaises(self.driver.Error,cur.fetchall) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) rows = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,len(self.samples))) - self.assertEqual(len(rows),len(self.samples), - 'cursor.fetchall did not retrieve all rows' - ) + self.assertTrue(cur.rowcount in (-1, len(self.samples))) + self.assertEqual( + len(rows), + len(self.samples), + "cursor.fetchall did not retrieve all rows", + ) rows = [r[0] for r in rows] rows.sort() - for i in range(0,len(self.samples)): - self.assertEqual(rows[i],self.samples[i], - 'cursor.fetchall retrieved incorrect rows' + for i in range(0, len(self.samples)): + self.assertEqual( + rows[i], self.samples[i], "cursor.fetchall retrieved incorrect rows" ) rows = cur.fetchall() self.assertEqual( - len(rows),0, - 'cursor.fetchall should return an empty list if called ' - 'after the whole result set has been fetched' - ) - self.assertTrue(cur.rowcount in (-1,len(self.samples))) + len(rows), + 0, + "cursor.fetchall should return an empty list if called " + "after the whole result set has been fetched", + ) + self.assertTrue(cur.rowcount in (-1, len(self.samples))) self.executeDDL2(cur) - cur.execute('select name from %sbarflys' % self.table_prefix) + cur.execute("select name from %sbarflys" % self.table_prefix) rows = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,0)) - self.assertEqual(len(rows),0, - 'cursor.fetchall should return an empty list if ' - 'a select query returns no rows' - ) + self.assertTrue(cur.rowcount in (-1, 0)) + self.assertEqual( + len(rows), + 0, + "cursor.fetchall should return an empty list if " + "a select query returns no rows", + ) finally: con.close() @@ -81,39 +87,40 @@ def test_fetchone(self): # cursor.fetchone should raise an Error if called before # executing a select-type query - self.assertRaises(self.driver.Error,cur.fetchone) + self.assertRaises(self.driver.Error, cur.fetchone) # cursor.fetchone should raise an Error if called after - # executing a query that cannnot return rows + # executing a query that cannot return rows self.executeDDL1(cur) -## self.assertRaises(self.driver.Error,cur.fetchone) + ## self.assertRaises(self.driver.Error,cur.fetchone) - cur.execute('select name from %sbooze' % self.table_prefix) - self.assertEqual(cur.fetchone(),None, - 'cursor.fetchone should return None if a query retrieves ' - 'no rows' - ) - self.assertTrue(cur.rowcount in (-1,0)) + cur.execute("select name from %sbooze" % self.table_prefix) + self.assertEqual( + cur.fetchone(), + None, + "cursor.fetchone should return None if a query retrieves no rows", + ) + self.assertTrue(cur.rowcount in (-1, 0)) # cursor.fetchone should raise an Error if called after - # executing a query that cannnot return rows - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) -## self.assertRaises(self.driver.Error,cur.fetchone) + # executing a query that cannot return rows + cur.execute( + "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) + ) + ## self.assertRaises(self.driver.Error,cur.fetchone) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) r = cur.fetchone() - self.assertEqual(len(r),1, - 'cursor.fetchone should have retrieved a single row' - ) - self.assertEqual(r[0],'Victoria Bitter', - 'cursor.fetchone retrieved incorrect data' - ) -## self.assertEqual(cur.fetchone(),None, -## 'cursor.fetchone should return None if no more rows available' -## ) - self.assertTrue(cur.rowcount in (-1,1)) + self.assertEqual( + len(r), 1, "cursor.fetchone should have retrieved a single row" + ) + self.assertEqual( + r[0], "Victoria Bitter", "cursor.fetchone retrieved incorrect data" + ) + ## self.assertEqual(cur.fetchone(),None, + ## 'cursor.fetchone should return None if no more rows available' + ## ) + self.assertTrue(cur.rowcount in (-1, 1)) finally: con.close() @@ -123,81 +130,84 @@ def test_rowcount(self): try: cur = con.cursor() self.executeDDL1(cur) -## self.assertEqual(cur.rowcount,-1, -## 'cursor.rowcount should be -1 after executing no-result ' -## 'statements' -## ) - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) -## self.assertTrue(cur.rowcount in (-1,1), -## 'cursor.rowcount should == number or rows inserted, or ' -## 'set to -1 after executing an insert statement' -## ) + ## self.assertEqual(cur.rowcount,-1, + ## 'cursor.rowcount should be -1 after executing no-result ' + ## 'statements' + ## ) + cur.execute( + "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) + ) + ## self.assertTrue(cur.rowcount in (-1,1), + ## 'cursor.rowcount should == number or rows inserted, or ' + ## 'set to -1 after executing an insert statement' + ## ) cur.execute("select name from %sbooze" % self.table_prefix) - self.assertTrue(cur.rowcount in (-1,1), - 'cursor.rowcount should == number of rows returned, or ' - 'set to -1 after executing a select statement' - ) + self.assertTrue( + cur.rowcount in (-1, 1), + "cursor.rowcount should == number of rows returned, or " + "set to -1 after executing a select statement", + ) self.executeDDL2(cur) -## self.assertEqual(cur.rowcount,-1, -## 'cursor.rowcount not being reset to -1 after executing ' -## 'no-result statements' -## ) + ## self.assertEqual(cur.rowcount,-1, + ## 'cursor.rowcount not being reset to -1 after executing ' + ## 'no-result statements' + ## ) finally: con.close() def test_callproc(self): - pass # performed in test_MySQL_capabilities - - def help_nextset_setUp(self,cur): - ''' Should create a procedure called deleteme - that returns two result sets, first the - number of rows in booze then "name from booze" - ''' - sql=""" + pass # performed in test_MySQL_capabilities + + def help_nextset_setUp(self, cur): + """Should create a procedure called deleteme + that returns two result sets, first the + number of rows in booze then "name from booze" + """ + sql = """ create procedure deleteme() begin select count(*) from %(tp)sbooze; select name from %(tp)sbooze; end - """ % dict(tp=self.table_prefix) + """ % dict( + tp=self.table_prefix + ) cur.execute(sql) - def help_nextset_tearDown(self,cur): - 'If cleaning up is needed after nextSetTest' + def help_nextset_tearDown(self, cur): + "If cleaning up is needed after nextSetTest" cur.execute("drop procedure deleteme") def test_nextset(self): - from warnings import warn con = self._connect() try: cur = con.cursor() - if not hasattr(cur,'nextset'): + if not hasattr(cur, "nextset"): return try: self.executeDDL1(cur) - sql=self._populate() + sql = self._populate() for sql in self._populate(): cur.execute(sql) self.help_nextset_setUp(cur) - cur.callproc('deleteme') - numberofrows=cur.fetchone() - assert numberofrows[0]== len(self.samples) + cur.callproc("deleteme") + numberofrows = cur.fetchone() + assert numberofrows[0] == len(self.samples) assert cur.nextset() - names=cur.fetchall() + names = cur.fetchall() assert len(names) == len(self.samples) - s=cur.nextset() + s = cur.nextset() if s: empty = cur.fetchall() - self.assertEqual(len(empty), 0, - "non-empty result set after other result sets") - #warn("Incompatibility: MySQL returns an empty result set for the CALL itself", + self.assertEqual( + len(empty), 0, "non-empty result set after other result sets" + ) + # warn("Incompatibility: MySQL returns an empty result set for the CALL itself", # Warning) - #assert s == None,'No more return sets, should return None' + # assert s == None,'No more return sets, should return None' finally: self.help_nextset_tearDown(cur) diff --git a/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py b/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py index 17fc2cde5..1545fbb5e 100644 --- a/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py +++ b/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py @@ -1,17 +1,10 @@ -import sys -try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest import pymysql + _mysql = pymysql from pymysql.constants import FIELD_TYPE from pymysql.tests import base -from pymysql._compat import PY2, long_type - -if not PY2: - basestring = str class TestDBAPISet(unittest.TestCase): @@ -33,17 +26,17 @@ class CoreModule(unittest.TestCase): def test_NULL(self): """Should have a NULL constant.""" - self.assertEqual(_mysql.NULL, 'NULL') + self.assertEqual(_mysql.NULL, "NULL") def test_version(self): """Version information sanity.""" - self.assertTrue(isinstance(_mysql.__version__, basestring)) + self.assertTrue(isinstance(_mysql.__version__, str)) self.assertTrue(isinstance(_mysql.version_info, tuple)) self.assertEqual(len(_mysql.version_info), 5) def test_client_info(self): - self.assertTrue(isinstance(_mysql.get_client_info(), basestring)) + self.assertTrue(isinstance(_mysql.get_client_info(), str)) def test_thread_safe(self): self.assertTrue(isinstance(_mysql.thread_safe(), int)) @@ -62,40 +55,45 @@ def tearDown(self): def test_thread_id(self): tid = self.conn.thread_id() - self.assertTrue(isinstance(tid, (int, long_type)), - "thread_id didn't return an integral value.") + self.assertTrue( + isinstance(tid, int), "thread_id didn't return an integral value." + ) - self.assertRaises(TypeError, self.conn.thread_id, ('evil',), - "thread_id shouldn't accept arguments.") + self.assertRaises( + TypeError, + self.conn.thread_id, + ("evil",), + "thread_id shouldn't accept arguments.", + ) def test_affected_rows(self): - self.assertEqual(self.conn.affected_rows(), 0, - "Should return 0 before we do anything.") + self.assertEqual( + self.conn.affected_rows(), 0, "Should return 0 before we do anything." + ) - - #def test_debug(self): - ## FIXME Only actually tests if you lack SUPER - #self.assertRaises(pymysql.OperationalError, - #self.conn.dump_debug_info) + # def test_debug(self): + ## FIXME Only actually tests if you lack SUPER + # self.assertRaises(pymysql.OperationalError, + # self.conn.dump_debug_info) def test_charset_name(self): - self.assertTrue(isinstance(self.conn.character_set_name(), basestring), - "Should return a string.") + self.assertTrue( + isinstance(self.conn.character_set_name(), str), "Should return a string." + ) def test_host_info(self): - assert isinstance(self.conn.get_host_info(), basestring), "should return a string" + assert isinstance(self.conn.get_host_info(), str), "should return a string" def test_proto_info(self): - self.assertTrue(isinstance(self.conn.get_proto_info(), int), - "Should return an int.") + self.assertTrue( + isinstance(self.conn.get_proto_info(), int), "Should return an int." + ) def test_server_info(self): - if sys.version_info[0] == 2: - self.assertTrue(isinstance(self.conn.get_server_info(), basestring), - "Should return an str.") - else: - self.assertTrue(isinstance(self.conn.get_server_info(), basestring), - "Should return an str.") + self.assertTrue( + isinstance(self.conn.get_server_info(), str), "Should return an str." + ) + if __name__ == "__main__": unittest.main() diff --git a/pymysql/util.py b/pymysql/util.py deleted file mode 100644 index 3e82ac7b5..000000000 --- a/pymysql/util.py +++ /dev/null @@ -1,22 +0,0 @@ -import struct - - -def byte2int(b): - if isinstance(b, int): - return b - else: - return struct.unpack("!B", b)[0] - - -def int2byte(i): - return struct.pack("!B", i) - - -def join_bytes(bs): - if len(bs) == 0: - return "" - else: - rv = bs[0] - for b in bs[1:]: - rv += b - return rv diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..8cd9ddb45 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,66 @@ +[project] +name = "PyMySQL" +description = "Pure Python MySQL Driver" +authors = [ + {name = "Inada Naoki", email = "songofacandy@gmail.com"}, + {name = "Yutaka Matsubara", email = "yutaka.matsubara@gmail.com"} +] +dependencies = [] + +requires-python = ">=3.7" +readme = "README.md" +license = {text = "MIT License"} +keywords = ["MySQL"] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Topic :: Database", +] +dynamic = ["version"] + +[project.optional-dependencies] +"rsa" = [ + "cryptography" +] +"ed25519" = [ + "PyNaCl>=1.4.0" +] + +[project.urls] +"Project" = "https://github.com/PyMySQL/PyMySQL" +"Documentation" = "https://pymysql.readthedocs.io/" + +[build-system] +requires = ["setuptools>=61"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +namespaces = false +include = ["pymysql*"] +exclude = ["tests*", "pymysql.tests*"] + +[tool.setuptools.dynamic] +version = {attr = "pymysql.VERSION_STRING"} + +[tool.ruff] +exclude = [ + "pymysql/tests/thirdparty", +] + +[tool.ruff.lint] +ignore = ["E721"] + +[tool.pdm.dev-dependencies] +dev = [ + "pytest-cov>=4.0.0", +] diff --git a/renovate.json b/renovate.json new file mode 100644 index 000000000..09e16da6b --- /dev/null +++ b/renovate.json @@ -0,0 +1,7 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": [ + "config:base" + ], + "dependencyDashboard": false +} diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 000000000..140d37067 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,4 @@ +cryptography +PyNaCl>=1.4.0 +pytest +pytest-cov diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 70f051613..000000000 --- a/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -cryptography - diff --git a/runtests.py b/runtests.py deleted file mode 100755 index ea3d9e8dd..000000000 --- a/runtests.py +++ /dev/null @@ -1,31 +0,0 @@ -#!/usr/bin/env python -import unittest2 - -from pymysql._compat import PYPY, JYTHON, IRONPYTHON - -#import pymysql -#pymysql.connections.DEBUG = True -#pymysql._auth.DEBUG = True - -if not (PYPY or JYTHON or IRONPYTHON): - import atexit - import gc - gc.set_debug(gc.DEBUG_UNCOLLECTABLE) - - @atexit.register - def report_uncollectable(): - import gc - if not gc.garbage: - print("No garbages!") - return - print('uncollectable objects') - for obj in gc.garbage: - print(obj) - if hasattr(obj, '__dict__'): - print(obj.__dict__) - for ref in gc.get_referrers(obj): - print("referrer:", ref) - print('---') - -import pymysql.tests -unittest2.main(pymysql.tests, verbosity=2) diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index a26a846b9..000000000 --- a/setup.cfg +++ /dev/null @@ -1,17 +0,0 @@ -[flake8] -ignore = E226,E301,E701 -exclude = tests,build -max-line-length = 119 - -[bdist_wheel] -universal = 1 - -[metadata] -license = "MIT" -license_file = LICENSE - -author=yutaka.matsubara -author_email=yutaka.matsubara@gmail.com - -maintainer=INADA Naoki -maintainer_email=songofacandy@gmail.com diff --git a/setup.py b/setup.py deleted file mode 100755 index 14650d1c7..000000000 --- a/setup.py +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env python -import io -from setuptools import setup, find_packages - -version = "0.9.2" - -with io.open('./README.rst', encoding='utf-8') as f: - readme = f.read() - -setup( - name="PyMySQL", - version=version, - url='https://github.com/PyMySQL/PyMySQL/', - project_urls={ - "Documentation": "https://pymysql.readthedocs.io/", - }, - description='Pure Python MySQL Driver', - long_description=readme, - packages=find_packages(exclude=['tests*', 'pymysql.tests*']), - install_requires=[ - "cryptography", - ], - classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: Implementation :: CPython', - 'Programming Language :: Python :: Implementation :: PyPy', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: MIT License', - 'Topic :: Database', - ], - keywords="MySQL", -) diff --git a/tests/test_auth.py b/tests/test_auth.py index 7d8573442..e5e2a64e5 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -10,7 +10,10 @@ port = 3306 ca = os.path.expanduser("~/ca.pem") -ssl = {'ca': ca, 'check_hostname': False} +ssl = {"ca": ca, "check_hostname": False} + +pass_sha256 = "pass_sha256_01234567890123456789" +pass_caching_sha2 = "pass_caching_sha2_01234567890123456789" def test_sha256_no_password(): @@ -24,12 +27,16 @@ def test_sha256_no_passowrd_ssl(): def test_sha256_password(): - con = pymysql.connect(user="user_sha256", password="pass_sha256", host=host, port=port, ssl=None) + con = pymysql.connect( + user="user_sha256", password=pass_sha256, host=host, port=port, ssl=None + ) con.close() def test_sha256_password_ssl(): - con = pymysql.connect(user="user_sha256", password="pass_sha256", host=host, port=port, ssl=ssl) + con = pymysql.connect( + user="user_sha256", password=pass_sha256, host=host, port=port, ssl=ssl + ) con.close() @@ -38,26 +45,50 @@ def test_caching_sha2_no_password(): con.close() -def test_caching_sha2_no_password(): +def test_caching_sha2_no_password_ssl(): con = pymysql.connect(user="nopass_caching_sha2", host=host, port=port, ssl=ssl) con.close() def test_caching_sha2_password(): - con = pymysql.connect(user="user_caching_sha2", password="pass_caching_sha2", host=host, port=port, ssl=None) + con = pymysql.connect( + user="user_caching_sha2", + password=pass_caching_sha2, + host=host, + port=port, + ssl=None, + ) con.close() # Fast path of caching sha2 - con = pymysql.connect(user="user_caching_sha2", password="pass_caching_sha2", host=host, port=port, ssl=None) + con = pymysql.connect( + user="user_caching_sha2", + password=pass_caching_sha2, + host=host, + port=port, + ssl=None, + ) con.query("FLUSH PRIVILEGES") con.close() def test_caching_sha2_password_ssl(): - con = pymysql.connect(user="user_caching_sha2", password="pass_caching_sha2", host=host, port=port, ssl=ssl) + con = pymysql.connect( + user="user_caching_sha2", + password=pass_caching_sha2, + host=host, + port=port, + ssl=ssl, + ) con.close() # Fast path of caching sha2 - con = pymysql.connect(user="user_caching_sha2", password="pass_caching_sha2", host=host, port=port, ssl=None) + con = pymysql.connect( + user="user_caching_sha2", + password=pass_caching_sha2, + host=host, + port=port, + ssl=None, + ) con.query("FLUSH PRIVILEGES") con.close() diff --git a/tox.ini b/tox.ini deleted file mode 100644 index a50364c9c..000000000 --- a/tox.ini +++ /dev/null @@ -1,10 +0,0 @@ -[tox] -envlist = py26,py27,py33,py34,pypy,pypy3 - -[testenv] -commands = coverage run ./runtests.py -deps = unittest2 - coverage -passenv = USER - PASSWORD - PAMSERVICE