diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml
new file mode 100644
index 000000000..253a13aca
--- /dev/null
+++ b/.github/FUNDING.yml
@@ -0,0 +1,12 @@
+# These are supported funding model platforms
+
+github: ["methane"] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
+patreon: # Replace with a single Patreon username
+open_collective: # Replace with a single Open Collective username
+ko_fi: # Replace with a single Ko-fi username
+tidelift: "pypi/PyMySQL" # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
+community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
+liberapay: # Replace with a single Liberapay username
+issuehunt: # Replace with a single IssueHunt username
+otechie: # Replace with a single Otechie username
+custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md
deleted file mode 100644
index 3e0fbe826..000000000
--- a/.github/ISSUE_TEMPLATE.md
+++ /dev/null
@@ -1,11 +0,0 @@
-This project is maintained one busy person with a frail wife and an infant daughter.
-My time and energy is a very limited resource. I'm not a teacher or free tech support.
-Don't ask a question here. Don't file an issue until you believe it's a not a problem with your code.
-Search for friendly volunteers who can teach you or review your code on ML or Q&A sites.
-
-See also: https://medium.com/@methane/why-you-must-not-ask-questions-on-github-issues-51d741d83fde
-
-
-If you're sure it's PyMySQL's issue, report the complete steps to reproduce, from creating database.
-
-I don't have time to investigate your issue from an incomplete code snippet.
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
new file mode 100644
index 000000000..f2bd4d300
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -0,0 +1,39 @@
+---
+name: Bug report
+about: Create a report to help us improve
+title: ''
+labels: ''
+assignees: ''
+
+---
+
+**Describe the bug**
+A clear and concise description of what the bug is.
+
+**To Reproduce**
+Complete steps to reproduce the behavior:
+
+Schema:
+
+```
+CREATE DATABASE ...
+CREATE TABLE ...
+```
+
+Code:
+
+```py
+import pymysql
+con = pymysql.connect(...)
+```
+
+**Expected behavior**
+A clear and concise description of what you expected to happen.
+
+**Environment**
+ - OS: [e.g. Windows, Linux]
+ - Server and version: [e.g. MySQL 8.0.19, MariaDB]
+ - PyMySQL version:
+
+**Additional context**
+Add any other context about the problem here.
diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml
new file mode 100644
index 000000000..df49979ea
--- /dev/null
+++ b/.github/workflows/codeql-analysis.yml
@@ -0,0 +1,62 @@
+# For most projects, this workflow file will not need changing; you simply need
+# to commit it to your repository.
+#
+# You may wish to alter this file to override the set of languages analyzed,
+# or to provide custom queries or build logic.
+#
+# ******** NOTE ********
+# We have attempted to detect the languages in your repository. Please check
+# the `language` matrix defined below to confirm you have the correct set of
+# supported CodeQL languages.
+#
+name: "CodeQL"
+
+on:
+ push:
+ branches: [ main ]
+ pull_request:
+ # The branches below must be a subset of the branches above
+ branches: [ main ]
+ schedule:
+ - cron: '34 7 * * 2'
+
+jobs:
+ analyze:
+ name: Analyze
+ runs-on: ubuntu-latest
+
+ strategy:
+ fail-fast: false
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+
+ # Initializes the CodeQL tools for scanning.
+ - name: Initialize CodeQL
+ uses: github/codeql-action/init@v3
+ with:
+ languages: "python"
+ # If you wish to specify custom queries, you can do so here or in a config file.
+ # By default, queries listed here will override any specified in a config file.
+ # Prefix the list here with "+" to use these queries and those in the config file.
+ # queries: ./path/to/local/query, your-org/your-repo/queries@main
+
+ # Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
+ # If this step fails, then you should remove it and run the build manually (see below)
+ - name: Autobuild
+ uses: github/codeql-action/autobuild@v3
+
+  # ℹ️ Command-line programs to run using the OS shell.
+  # 📚 https://git.io/JvXDl
+
+  # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
+ # and modify them (or add more) to build your code if your project
+ # uses a compiled language
+
+ #- run: |
+ # make bootstrap
+ # make release
+
+ - name: Perform CodeQL Analysis
+ uses: github/codeql-action/analyze@v3
diff --git a/.github/workflows/codesee-arch-diagram.yml b/.github/workflows/codesee-arch-diagram.yml
new file mode 100644
index 000000000..806d41d12
--- /dev/null
+++ b/.github/workflows/codesee-arch-diagram.yml
@@ -0,0 +1,23 @@
+# This workflow was added by CodeSee. Learn more at https://codesee.io/
+# This is v2.0 of this workflow file
+on:
+ push:
+ branches:
+ - main
+ pull_request_target:
+ types: [opened, synchronize, reopened]
+
+name: CodeSee
+
+permissions: read-all
+
+jobs:
+ codesee:
+ runs-on: ubuntu-latest
+ continue-on-error: true
+ name: Analyze the repo with CodeSee
+ steps:
+ - uses: Codesee-io/codesee-action@v2
+ with:
+ codesee-token: ${{ secrets.CODESEE_ARCH_DIAG_API_TOKEN }}
+ codesee-url: https://app.codesee.io
diff --git a/.github/workflows/django.yaml b/.github/workflows/django.yaml
new file mode 100644
index 000000000..5c4609543
--- /dev/null
+++ b/.github/workflows/django.yaml
@@ -0,0 +1,66 @@
+name: Django test
+
+on:
+ push:
+ # branches: ["main"]
+ # pull_request:
+
+jobs:
+ django-test:
+ name: "Run Django LTS test suite"
+ runs-on: ubuntu-latest
+    # There are some known differences between MySQLdb and PyMySQL.
+ continue-on-error: true
+ env:
+ PIP_NO_PYTHON_VERSION_WARNING: 1
+ PIP_DISABLE_PIP_VERSION_CHECK: 1
+ # DJANGO_VERSION: "3.2.19"
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ # Django 3.2.9+ supports Python 3.10
+ # https://docs.djangoproject.com/ja/3.2/releases/3.2/
+ - django: "3.2.19"
+ python: "3.10"
+
+ - django: "4.2.1"
+ python: "3.11"
+
+ steps:
+ - name: Start MySQL
+ run: |
+ sudo systemctl start mysql.service
+ mysql_tzinfo_to_sql /usr/share/zoneinfo | mysql -uroot -proot mysql
+ mysql -uroot -proot -e "set global innodb_flush_log_at_trx_commit=0;"
+ mysql -uroot -proot -e "CREATE USER 'scott'@'%' IDENTIFIED BY 'tiger'; GRANT ALL ON *.* TO scott;"
+ mysql -uroot -proot -e "CREATE DATABASE django_default; CREATE DATABASE django_other;"
+
+ - uses: actions/checkout@v4
+
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python }}
+
+ - name: Install mysqlclient
+ run: |
+ #pip install mysqlclient # Use stable version
+ pip install .[rsa]
+
+ - name: Setup Django
+ run: |
+ sudo apt-get install libmemcached-dev
+ wget https://github.com/django/django/archive/${{ matrix.django }}.tar.gz
+ tar xf ${{ matrix.django }}.tar.gz
+ cp ci/test_mysql.py django-${{ matrix.django }}/tests/
+ cd django-${{ matrix.django }}
+ pip install . -r tests/requirements/py3.txt
+
+ - name: Run Django test
+ run: |
+ cd django-${{ matrix.django }}/tests/
+        # test_runner does not use our test_mysql.py
+        # We can't run the whole Django test suite for now.
+        # Run only the backends tests
+ DJANGO_SETTINGS_MODULE=test_mysql python runtests.py backends
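
The Django run above works because `ci/test_mysql.py` (added later in this diff) calls `pymysql.install_as_MySQLdb()`, so Django's `django.db.backends.mysql` backend transparently uses PyMySQL instead of mysqlclient. A minimal sketch of the same shim outside the test suite; the host and the `scott`/`tiger` credentials mirror the CI setup above and are illustrative only:

```python
# Sketch of the mysqlclient-compatibility shim used by ci/test_mysql.py.
# Host and credentials mirror the CI setup above and are illustrative only.
import pymysql

pymysql.install_as_MySQLdb()

import MySQLdb  # now resolves to the pymysql module

conn = MySQLdb.connect(
    host="127.0.0.1", user="scott", password="tiger", database="django_default"
)
with conn.cursor() as cur:
    cur.execute("SELECT VERSION()")
    print(cur.fetchone())
conn.close()
```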
diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
new file mode 100644
index 000000000..269211c25
--- /dev/null
+++ b/.github/workflows/lint.yaml
@@ -0,0 +1,25 @@
+name: Lint
+
+on:
+ push:
+ branches: ["main"]
+ paths:
+ - '**.py'
+ pull_request:
+ paths:
+ - '**.py'
+
+jobs:
+ lint:
+ runs-on: ubuntu-latest
+ steps:
+ - name: checkout
+ uses: actions/checkout@v4
+
+ - name: lint
+ uses: chartboost/ruff-action@v1
+
+ - name: check format
+ uses: chartboost/ruff-action@v1
+ with:
+ args: "format --diff"
diff --git a/.github/workflows/lock.yml b/.github/workflows/lock.yml
new file mode 100644
index 000000000..21449e3b8
--- /dev/null
+++ b/.github/workflows/lock.yml
@@ -0,0 +1,17 @@
+name: 'Lock Threads'
+
+on:
+ schedule:
+ - cron: '30 9 * * 1'
+
+permissions:
+ issues: write
+ pull-requests: write
+
+jobs:
+ lock-threads:
+ if: github.repository == 'PyMySQL/PyMySQL'
+ runs-on: ubuntu-latest
+ steps:
+ - uses: dessant/lock-threads@v5
+
diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
new file mode 100644
index 000000000..6d59d8c4b
--- /dev/null
+++ b/.github/workflows/test.yaml
@@ -0,0 +1,109 @@
+name: Test
+
+on:
+ push:
+ pull_request:
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}
+ cancel-in-progress: true
+
+env:
+ FORCE_COLOR: 1
+
+jobs:
+ test:
+ runs-on: ubuntu-latest
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ - db: "mariadb:10.4"
+ py: "3.8"
+
+ - db: "mariadb:10.5"
+ py: "3.7"
+
+ - db: "mariadb:10.6"
+ py: "3.11"
+
+ - db: "mariadb:10.6"
+ py: "3.12"
+
+ - db: "mariadb:lts"
+ py: "3.9"
+
+ - db: "mysql:5.7"
+ py: "pypy-3.8"
+
+ - db: "mysql:8.0"
+ py: "3.9"
+ mysql_auth: true
+
+ - db: "mysql:8.0"
+ py: "3.10"
+
+ services:
+ mysql:
+ image: "${{ matrix.db }}"
+ ports:
+ - 3306:3306
+ env:
+ MYSQL_ALLOW_EMPTY_PASSWORD: yes
+ MARIADB_ALLOW_EMPTY_ROOT_PASSWORD: yes
+ options: "--name=mysqld"
+ volumes:
+ - /run/mysqld:/run/mysqld
+
+ steps:
+ - uses: actions/checkout@v4
+
+ - name: Workaround MySQL container permissions
+ if: startsWith(matrix.db, 'mysql')
+ run: |
+ sudo chown 999:999 /run/mysqld
+ /usr/bin/docker ps --all --filter status=exited --no-trunc --format "{{.ID}}" | xargs -r /usr/bin/docker start
+
+ - name: Set up Python ${{ matrix.py }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.py }}
+ allow-prereleases: true
+ cache: 'pip'
+ cache-dependency-path: 'requirements-dev.txt'
+
+ - name: Install dependency
+ run: |
+ pip install --upgrade -r requirements-dev.txt
+
+ - name: Set up MySQL
+ run: |
+ while :
+ do
+ sleep 1
+ mysql -h127.0.0.1 -uroot -e 'select version()' && break
+ done
+ mysql -h127.0.0.1 -uroot -e "SET GLOBAL local_infile=on"
+ mysql -h127.0.0.1 -uroot --comments < ci/docker-entrypoint-initdb.d/init.sql
+ mysql -h127.0.0.1 -uroot --comments < ci/docker-entrypoint-initdb.d/mysql.sql
+ mysql -h127.0.0.1 -uroot --comments < ci/docker-entrypoint-initdb.d/mariadb.sql
+ cp ci/docker.json pymysql/tests/databases.json
+
+ - name: Run test
+ run: |
+ pytest -v --cov --cov-config .coveragerc pymysql
+ pytest -v --cov-append --cov-config .coveragerc --doctest-modules pymysql/converters.py
+
+ - name: Run MySQL8 auth test
+ if: ${{ matrix.mysql_auth }}
+ run: |
+ docker cp mysqld:/var/lib/mysql/public_key.pem "${HOME}"
+ docker cp mysqld:/var/lib/mysql/ca.pem "${HOME}"
+ docker cp mysqld:/var/lib/mysql/server-cert.pem "${HOME}"
+ docker cp mysqld:/var/lib/mysql/client-key.pem "${HOME}"
+ docker cp mysqld:/var/lib/mysql/client-cert.pem "${HOME}"
+ pytest -v --cov-append --cov-config .coveragerc tests/test_auth.py;
+
+ - name: Upload coverage reports to Codecov
+ if: github.repository == 'PyMySQL/PyMySQL'
+ uses: codecov/codecov-action@v4
diff --git a/.gitignore b/.gitignore
index 98f4d45c8..09a5654fb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,3 +13,4 @@
/pymysql/tests/databases.json
__pycache__
Pipfile.lock
+pdm.lock
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
new file mode 100644
index 000000000..59fdb65df
--- /dev/null
+++ b/.readthedocs.yaml
@@ -0,0 +1,17 @@
+# .readthedocs.yaml
+# Read the Docs configuration file
+# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
+version: 2
+
+build:
+ os: ubuntu-22.04
+ tools:
+ python: "3.12"
+
+python:
+ install:
+ - requirements: docs/requirements.txt
+
+# Build documentation in the docs/ directory with Sphinx
+sphinx:
+ configuration: docs/source/conf.py
diff --git a/.travis.yml b/.travis.yml
deleted file mode 100644
index f4a7cc74e..000000000
--- a/.travis.yml
+++ /dev/null
@@ -1,67 +0,0 @@
-# vim: sw=2 ts=2 sts=2 expandtab
-
-sudo: required
-language: python
-services:
- - docker
-
-cache: pip
-
-matrix:
- include:
- - env:
- - DB=mariadb:5.5
- python: "3.5"
- - env:
- - DB=mariadb:10.0
- python: "3.6"
- - env:
- - DB=mariadb:10.1
- python: "pypy"
- - env:
- - DB=mariadb:10.2
- python: "2.7"
- - env:
- - DB=mariadb:10.3
- python: "3.7-dev"
- - env:
- - DB=mysql:5.5
- python: "3.5"
- - env:
- - DB=mysql:5.6
- python: "3.6"
- - env:
- - DB=mysql:5.7
- python: "3.4"
- - env:
- - DB=mysql:8.0
- - TEST_AUTH=yes
- python: "3.7-dev"
- - env:
- - DB=mysql:8.0
- - TEST_AUTH=yes
- python: "2.7"
-
-# different py version from 5.6 and 5.7 as cache seems to be based on py version
-# http://dev.mysql.com/downloads/mysql/5.7.html has latest development release version
-# really only need libaio1 for DB builds however libaio-dev is whitelisted for container builds and liaio1 isn't
-install:
- - pip install -U coveralls unittest2 coverage cryptography pytest pytest-cov
-
-before_script:
- - ./.travis/initializedb.sh
- - python -VV
- - rm -f ~/.my.cnf # set in .travis.initialize.db.sh for the above commands - we should be using database.json however
- - export COVERALLS_PARALLEL=true
-
-script:
- - coverage run ./runtests.py
- - if [ "${TEST_AUTH}" = "yes" ];
- then pytest -v --cov-config .coveragerc tests;
- fi
- - if [ ! -z "${DB}" ];
- then docker logs mysqld;
- fi
-
-after_success:
- - coveralls
diff --git a/.travis/database.json b/.travis/database.json
deleted file mode 100644
index ab1f60a3a..000000000
--- a/.travis/database.json
+++ /dev/null
@@ -1,4 +0,0 @@
-[
- {"host": "localhost", "unix_socket": "/var/run/mysqld/mysqld.sock", "user": "root", "passwd": "", "db": "test1", "use_unicode": true, "local_infile": true},
- {"host": "127.0.0.1", "port": 3306, "user": "test2", "password": "some password", "db": "test2" }
-]
diff --git a/.travis/docker.json b/.travis/docker.json
deleted file mode 100644
index b851fb6da..000000000
--- a/.travis/docker.json
+++ /dev/null
@@ -1,4 +0,0 @@
-[
- {"host": "127.0.0.1", "port": 3306, "user": "root", "passwd": "", "db": "test1", "use_unicode": true, "local_infile": true},
- {"host": "127.0.0.1", "port": 3306, "user": "test2", "password": "some password", "db": "test2" }
-]
diff --git a/.travis/initializedb.sh b/.travis/initializedb.sh
deleted file mode 100755
index d9897e49c..000000000
--- a/.travis/initializedb.sh
+++ /dev/null
@@ -1,72 +0,0 @@
-#!/bin/bash
-
-#debug
-set -x
-#verbose
-set -v
-
-if [ ! -z "${DB}" ]; then
- # disable existing database server in case of accidential connection
- sudo service mysql stop
-
- docker pull ${DB}
- docker run -it --name=mysqld -d -e MYSQL_ALLOW_EMPTY_PASSWORD=yes -p 3306:3306 ${DB}
- sleep 10
-
- mysql() {
- docker exec mysqld mysql "${@}"
- }
- while :
- do
- sleep 5
- mysql -e 'select version()'
- if [ $? = 0 ]; then
- break
- fi
- echo "server logs"
- docker logs --tail 5 mysqld
- done
-
- mysql -e 'select VERSION()'
-
- if [ $DB == 'mysql:8.0' ]; then
- WITH_PLUGIN='with mysql_native_password'
- mysql -e 'SET GLOBAL local_infile=on'
- docker cp mysqld:/var/lib/mysql/public_key.pem "${HOME}"
- docker cp mysqld:/var/lib/mysql/ca.pem "${HOME}"
- docker cp mysqld:/var/lib/mysql/server-cert.pem "${HOME}"
- docker cp mysqld:/var/lib/mysql/client-key.pem "${HOME}"
- docker cp mysqld:/var/lib/mysql/client-cert.pem "${HOME}"
-
- # Test user for auth test
- mysql -e '
- CREATE USER
- user_sha256 IDENTIFIED WITH "sha256_password" BY "pass_sha256",
- nopass_sha256 IDENTIFIED WITH "sha256_password",
- user_caching_sha2 IDENTIFIED WITH "caching_sha2_password" BY "pass_caching_sha2",
- nopass_caching_sha2 IDENTIFIED WITH "caching_sha2_password"
- PASSWORD EXPIRE NEVER;'
- mysql -e 'GRANT RELOAD ON *.* TO user_caching_sha2;'
- else
- WITH_PLUGIN=''
- fi
-
- mysql -uroot -e 'create database test1 DEFAULT CHARACTER SET utf8mb4'
- mysql -uroot -e 'create database test2 DEFAULT CHARACTER SET utf8mb4'
-
- mysql -u root -e "create user test2 identified ${WITH_PLUGIN} by 'some password'; grant all on test2.* to test2;"
- mysql -u root -e "create user test2@localhost identified ${WITH_PLUGIN} by 'some password'; grant all on test2.* to test2@localhost;"
-
- cp .travis/docker.json pymysql/tests/databases.json
-else
- cat ~/.my.cnf
-
- mysql -e 'select VERSION()'
- mysql -e 'create database test1 DEFAULT CHARACTER SET utf8 DEFAULT COLLATE utf8_general_ci;'
- mysql -e 'create database test2 DEFAULT CHARACTER SET utf8 DEFAULT COLLATE utf8_general_ci;'
-
- mysql -u root -e "create user test2 identified by 'some password'; grant all on test2.* to test2;"
- mysql -u root -e "create user test2@localhost identified by 'some password'; grant all on test2.* to test2@localhost;"
-
- cp .travis/database.json pymysql/tests/databases.json
-fi
diff --git a/CHANGELOG b/CHANGELOG.md
similarity index 61%
rename from CHANGELOG
rename to CHANGELOG.md
index b4372fed6..825dc47c1 100644
--- a/CHANGELOG
+++ b/CHANGELOG.md
@@ -1,5 +1,144 @@
# Changes
+## Backward incompatible changes planned in the future
+
+* Error classes in the Cursor class will be removed after 2024-06
+* `Connection.set_charset(charset)` will be removed after 2024-06
+* `db` and `passwd` will emit DeprecationWarning in v1.2. See #933.
+* `Connection.ping(reconnect)` will change its default to not reconnect.
+
+## v1.1.1
+
+Release date: 2024-05-21
+
+> [!WARNING]
+> This release fixes a vulnerability (CVE-2024-36039).
+> All users are recommended to update to this version.
+>
+> If you cannot update soon, check that input values from
+> untrusted sources have the expected type. Only dict input
+> from an untrusted source can be an attack vector.
+
+* Prohibit dict parameter for `Cursor.execute()`. It didn't produce valid SQL
+ and might cause SQL injection. (CVE-2024-36039)
+
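
The safe pattern is unchanged: pass parameters as a sequence with `%s` placeholders and let PyMySQL do the escaping, rather than handing a dict built from untrusted input to `Cursor.execute()`. A minimal sketch; the table and column names are illustrative only:

```python
# Illustrative only: parameters are passed as a tuple, never interpolated into the
# SQL string, and dicts from untrusted input are no longer accepted by Cursor.execute().
def find_user(connection, untrusted_email):
    with connection.cursor() as cursor:
        sql = "SELECT `id`, `password` FROM `users` WHERE `email` = %s"
        cursor.execute(sql, (untrusted_email,))
        return cursor.fetchone()
```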
+## v1.1.0
+
+Release date: 2023-06-26
+
+* Fixed SSCursor raising OperationalError for query timeouts on wrong statement (#1032)
+* Exposed `Cursor.warning_count` to check for warnings without additional query (#1056)
+* Make Cursor an iterator (#995)
+* Support '_' in key name in my.cnf (#1114)
+* `Cursor.fetchall()` returns an empty list instead of a tuple (#1115). Note that `Cursor.fetchmany()` still returns an empty tuple after reading all rows, for compatibility with Django.
+* Deprecate Error classes in the Cursor class (#1117)
+* Add `Connection.set_character_set(charset, collation=None)`. This method is compatible with mysqlclient. (#1119)
+* Deprecate `Connection.set_charset(charset)` (#1119)
+* New connections always send a "SET NAMES charset [COLLATE collation]" query. (#1119)
+  Since the collation table varies across MySQL server versions, specifying the collation in the handshake is fragile.
+* Support `charset="utf8mb3"` option (#1127)
+
+
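
A hedged sketch of the 1.1.0 additions listed above used together; the connection parameters are placeholders and the query is only there to produce a server warning:

```python
import pymysql

# Illustrative only; connection parameters are placeholders.
con = pymysql.connect(host="localhost", user="user", password="passwd", database="db")
con.set_character_set("utf8mb4")  # mysqlclient-compatible API (#1119)
with con.cursor() as cur:
    cur.execute("SELECT CAST('abc' AS UNSIGNED)")  # produces a truncation warning
    for row in cur:  # the cursor itself can be iterated (#995)
        print(row)
    print(cur.warning_count)  # warnings available without an extra query (#1056)
con.close()
```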
+## v1.0.3
+
+Release date: 2023-03-28
+
+* Dropped support of end of life MySQL version 5.6
+* Dropped support of end of life MariaDB versions below 10.3
+* Dropped support of end of life Python version 3.6
+* Removed `_last_executed` because of duplication with `_executed` by @rajat315315 in https://github.com/PyMySQL/PyMySQL/pull/948
+* Fix generating authentication response with long strings by @netch80 in https://github.com/PyMySQL/PyMySQL/pull/988
+* update pymysql.constants.CR by @Nothing4You in https://github.com/PyMySQL/PyMySQL/pull/1029
+* Document that the ssl connection parameter can be an SSLContext by @cakemanny in https://github.com/PyMySQL/PyMySQL/pull/1045
+* Raise ProgrammingError on -np.inf in addition to np.inf by @cdcadman in https://github.com/PyMySQL/PyMySQL/pull/1067
+* Use Python 3.11 release instead of -dev in tests by @Nothing4You in https://github.com/PyMySQL/PyMySQL/pull/1076
+
+
+## v1.0.2
+
+Release date: 2021-01-09
+
+* Fix that `user`, `password`, `host`, and `database` were still accepted as positional
+  arguments. All arguments of `connect()` are now keyword-only. (#941)
+
+
+## v1.0.1
+
+Release date: 2021-01-08
+
+* Stop emitting DeprecationWarning for use of ``db`` and ``passwd``.
+ Note that they are still deprecated. (#939)
+* Add ``python_requires=">=3.6"`` to setup.py. (#936)
+
+
+## v1.0.0
+
+Release date: 2021-01-07
+
+Backward incompatible changes:
+
+* Python 2.7 and 3.5 are not supported.
+* ``connect()`` uses keyword-only arguments. Users must pass keyword arguments.
+* ``connect()`` kwargs ``db`` and ``passwd`` are now deprecated; Use ``database`` and ``password`` instead.
+* old_password authentication method (used by MySQL older than 4.1) is not supported.
+* MySQL 5.5 and MariaDB 5.5 are not officially supported, although they may still work.
+* Removed ``escape_dict``, ``escape_sequence``, and ``escape_string`` from ``pymysql``
+ module. They are still in ``pymysql.converters``.
+
+Other changes:
+
+* Connection supports context manager API. ``__exit__`` closes the connection. (#886)
+* Add MySQL Connector/Python compatible TLS options (#903)
+* Major code cleanup; PyMySQL now uses black and flake8.
+
+
+## v0.10.1
+
+Release date: 2020-09-10
+
+* Fix missing import of ProgrammingError. (#878)
+* Fix auth switch request handling. (#890)
+
+
+## v0.10.0
+
+Release date: 2020-07-18
+
+This version is the last version supporting Python 2.7.
+
+* MariaDB ed25519 auth is supported.
+* Python 3.4 support is dropped.
+* Context manager interface is removed from `Connection`. It will be added
+  back later with a different meaning.
+* MySQL warnings are not shown by default because many users reported issues to the
+  PyMySQL issue tracker when they saw warnings. You need to run "SHOW WARNINGS"
+  explicitly when you want to see them (see the sketch below).
+* Formatting of float objects is changed from "3.14" to "3.14e0".
+* Use cp1252 codec for latin1 charset.
+* Fix decimal literal.
+* TRUNCATED_WRONG_VALUE_FOR_FIELD and ILLEGAL_VALUE_FOR_TYPE are now
+ DataError instead of InternalError.
+
+
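
Since 0.10.0 a program that wants to inspect warnings has to ask the server for them. A minimal sketch, assuming `con` is an open PyMySQL connection; the division-by-zero query is only there to generate a warning:

```python
# Warnings are no longer surfaced automatically since 0.10.0; ask the server.
with con.cursor() as cur:
    cur.execute("SELECT 1/0")  # division by zero yields NULL plus a warning
    cur.execute("SHOW WARNINGS")
    for level, code, message in cur.fetchall():
        print(level, code, message)
```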
+## 0.9.3
+
+Release date: 2018-12-18
+
+* cryptography dependency is optional now.
+* Fix old_password (used before MySQL 4.1) support.
+* Deprecate old_password.
+* Stop sending ``sys.argv[0]`` for connection attribute "program_name".
+* Close the connection when an unknown error happens.
+* Deprecate context manager API of Connection object.
+
+## 0.9.2
+
+Release date: 2018-07-04
+
+* Disabled unintentionally enabled debug log
+* Removed unintentionally installed tests
+
+
## 0.9.1
Release date: 2018-07-03
@@ -32,7 +171,7 @@ Release date: 2018-05-07
* Many test suite improvements, especially adding MySQL 8.0 and using Docker.
Thanks to Daniel Black.
-* Droppped support for old Python and MySQL whih is not tested long time.
+* Dropped support for old Python and MySQL versions which have not been tested for a long time.
## 0.8
@@ -110,7 +249,7 @@ Release date: 2016-08-30
Release date: 2016-07-29
* Fix SELECT JSON type cause UnicodeError
-* Avoid float convertion while parsing microseconds
+* Avoid float conversion while parsing microseconds
* Warning has number
* SSCursor supports warnings
diff --git a/MANIFEST.in b/MANIFEST.in
index 0a5207928..e2e577a9d 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1 +1 @@
-include README.rst LICENSE CHANGELOG
+include README.md LICENSE CHANGELOG.md
diff --git a/Pipfile b/Pipfile
deleted file mode 100644
index 0e142ba35..000000000
--- a/Pipfile
+++ /dev/null
@@ -1,12 +0,0 @@
-[[source]]
-url = "https://pypi.python.org/simple"
-verify_ssl = true
-name = "pypi"
-
-[packages]
-cryptography = "*"
-
-[dev-packages]
-pytest = "*"
-unittest2 = "*"
-twine = "*"
diff --git a/README.md b/README.md
new file mode 100644
index 000000000..32f5df2f4
--- /dev/null
+++ b/README.md
@@ -0,0 +1,105 @@
+[![Documentation Status](https://readthedocs.org/projects/pymysql/badge/?version=latest)](https://pymysql.readthedocs.io/)
+[![codecov](https://codecov.io/gh/PyMySQL/PyMySQL/branch/main/graph/badge.svg)](https://codecov.io/gh/PyMySQL/PyMySQL)
+
+# PyMySQL
+
+This package contains a pure-Python MySQL client library, based on [PEP
+249](https://www.python.org/dev/peps/pep-0249/).
+
+## Requirements
+
+- Python -- one of the following:
+ - [CPython](https://www.python.org/) : 3.7 and newer
+ - [PyPy](https://pypy.org/) : Latest 3.x version
+- MySQL Server -- one of the following:
+ - [MySQL](https://www.mysql.com/) \>= 5.7
+ - [MariaDB](https://mariadb.org/) \>= 10.4
+
+## Installation
+
+Package is uploaded on [PyPI](https://pypi.org/project/PyMySQL).
+
+You can install it with pip:
+
+ $ python3 -m pip install PyMySQL
+
+To use "sha256_password" or "caching_sha2_password" for authenticate,
+you need to install additional dependency:
+
+ $ python3 -m pip install PyMySQL[rsa]
+
+To use MariaDB's "ed25519" authentication method, you need to install an
+additional dependency:
+
+ $ python3 -m pip install PyMySQL[ed25519]
+
+## Documentation
+
+Documentation is available online: <https://pymysql.readthedocs.io/>
+
+For support, please refer to the
+[StackOverflow](https://stackoverflow.com/questions/tagged/pymysql).
+
+## Example
+
+The following examples make use of a simple table
+
+``` sql
+CREATE TABLE `users` (
+ `id` int(11) NOT NULL AUTO_INCREMENT,
+ `email` varchar(255) COLLATE utf8_bin NOT NULL,
+ `password` varchar(255) COLLATE utf8_bin NOT NULL,
+ PRIMARY KEY (`id`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin
+AUTO_INCREMENT=1 ;
+```
+
+``` python
+import pymysql.cursors
+
+# Connect to the database
+connection = pymysql.connect(host='localhost',
+ user='user',
+ password='passwd',
+ database='db',
+ cursorclass=pymysql.cursors.DictCursor)
+
+with connection:
+ with connection.cursor() as cursor:
+ # Create a new record
+ sql = "INSERT INTO `users` (`email`, `password`) VALUES (%s, %s)"
+ cursor.execute(sql, ('webmaster@python.org', 'very-secret'))
+
+ # connection is not autocommit by default. So you must commit to save
+ # your changes.
+ connection.commit()
+
+ with connection.cursor() as cursor:
+ # Read a single record
+ sql = "SELECT `id`, `password` FROM `users` WHERE `email`=%s"
+ cursor.execute(sql, ('webmaster@python.org',))
+ result = cursor.fetchone()
+ print(result)
+```
+
+This example will print:
+
+``` python
+{'password': 'very-secret', 'id': 1}
+```
+
+## Resources
+
+- DB-API 2.0: <https://www.python.org/dev/peps/pep-0249/>
+- MySQL Reference Manuals: <http://dev.mysql.com/doc/>
+- MySQL client/server protocol:
+  <http://dev.mysql.com/doc/internals/en/client-server-protocol.html>
+- "Connector" channel in MySQL Community Slack:
+  <http://lefred.be/mysql-community-on-slack/>
+- PyMySQL mailing list:
+  <https://groups.google.com/forum/#!forum/pymysql-users>
+
+## License
+
+PyMySQL is released under the MIT License. See LICENSE for more
+information.
diff --git a/README.rst b/README.rst
deleted file mode 100644
index 1c7fba54c..000000000
--- a/README.rst
+++ /dev/null
@@ -1,145 +0,0 @@
-.. image:: https://readthedocs.org/projects/pymysql/badge/?version=latest
- :target: https://pymysql.readthedocs.io/
- :alt: Documentation Status
-
-.. image:: https://badge.fury.io/py/PyMySQL.svg
- :target: https://badge.fury.io/py/PyMySQL
-
-.. image:: https://travis-ci.org/PyMySQL/PyMySQL.svg?branch=master
- :target: https://travis-ci.org/PyMySQL/PyMySQL
-
-.. image:: https://coveralls.io/repos/PyMySQL/PyMySQL/badge.svg?branch=master&service=github
- :target: https://coveralls.io/github/PyMySQL/PyMySQL?branch=master
-
-.. image:: https://img.shields.io/badge/license-MIT-blue.svg
- :target: https://github.com/PyMySQL/PyMySQL/blob/master/LICENSE
-
-
-PyMySQL
-=======
-
-.. contents:: Table of Contents
- :local:
-
-This package contains a pure-Python MySQL client library, based on `PEP 249`_.
-
-Most public APIs are compatible with mysqlclient and MySQLdb.
-
-NOTE: PyMySQL doesn't support low level APIs `_mysql` provides like `data_seek`,
-`store_result`, and `use_result`. You should use high level APIs defined in `PEP 249`_.
-But some APIs like `autocommit` and `ping` are supported because `PEP 249`_ doesn't cover
-their usecase.
-
-.. _`PEP 249`: https://www.python.org/dev/peps/pep-0249/
-
-
-Requirements
--------------
-
-* Python -- one of the following:
-
- - CPython_ : 2.7 and >= 3.4
- - PyPy_ : Latest version
-
-* MySQL Server -- one of the following:
-
- - MySQL_ >= 5.5
- - MariaDB_ >= 5.5
-
-.. _CPython: https://www.python.org/
-.. _PyPy: https://pypy.org/
-.. _MySQL: https://www.mysql.com/
-.. _MariaDB: https://mariadb.org/
-
-
-Installation
-------------
-
-Package is uploaded on `PyPI `_.
-
-You can install it with pip::
-
- $ pip3 install PyMySQL
-
-
-Documentation
--------------
-
-Documentation is available online: https://pymysql.readthedocs.io/
-
-For support, please refer to the `StackOverflow
-`_.
-
-Example
--------
-
-The following examples make use of a simple table
-
-.. code:: sql
-
- CREATE TABLE `users` (
- `id` int(11) NOT NULL AUTO_INCREMENT,
- `email` varchar(255) COLLATE utf8_bin NOT NULL,
- `password` varchar(255) COLLATE utf8_bin NOT NULL,
- PRIMARY KEY (`id`)
- ) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin
- AUTO_INCREMENT=1 ;
-
-
-.. code:: python
-
- import pymysql.cursors
-
- # Connect to the database
- connection = pymysql.connect(host='localhost',
- user='user',
- password='passwd',
- db='db',
- charset='utf8mb4',
- cursorclass=pymysql.cursors.DictCursor)
-
- try:
- with connection.cursor() as cursor:
- # Create a new record
- sql = "INSERT INTO `users` (`email`, `password`) VALUES (%s, %s)"
- cursor.execute(sql, ('webmaster@python.org', 'very-secret'))
-
- # connection is not autocommit by default. So you must commit to save
- # your changes.
- connection.commit()
-
- with connection.cursor() as cursor:
- # Read a single record
- sql = "SELECT `id`, `password` FROM `users` WHERE `email`=%s"
- cursor.execute(sql, ('webmaster@python.org',))
- result = cursor.fetchone()
- print(result)
- finally:
- connection.close()
-
-This example will print:
-
-.. code:: python
-
- {'password': 'very-secret', 'id': 1}
-
-
-Resources
----------
-
-* DB-API 2.0: http://www.python.org/dev/peps/pep-0249
-
-* MySQL Reference Manuals: http://dev.mysql.com/doc/
-
-* MySQL client/server protocol:
- http://dev.mysql.com/doc/internals/en/client-server-protocol.html
-
-* "Connector" channel in MySQL Community Slack:
- http://lefred.be/mysql-community-on-slack/
-
-* PyMySQL mailing list: https://groups.google.com/forum/#!forum/pymysql-users
-
-License
--------
-
-PyMySQL is released under the MIT License. See LICENSE for more information.
diff --git a/SECURITY.md b/SECURITY.md
new file mode 100644
index 000000000..da9c516dd
--- /dev/null
+++ b/SECURITY.md
@@ -0,0 +1,5 @@
+## Security contact information
+
+To report a security vulnerability, please use the
+[Tidelift security contact](https://tidelift.com/security).
+Tidelift will coordinate the fix and disclosure.
diff --git a/ci/database.json b/ci/database.json
new file mode 100644
index 000000000..aad0bfb29
--- /dev/null
+++ b/ci/database.json
@@ -0,0 +1,4 @@
+[
+ {"host": "localhost", "unix_socket": "/var/run/mysqld/mysqld.sock", "user": "root", "password": "", "database": "test1", "use_unicode": true, "local_infile": true},
+ {"host": "127.0.0.1", "port": 3306, "user": "test2", "password": "some password", "database": "test2" }
+]
diff --git a/ci/docker-entrypoint-initdb.d/README b/ci/docker-entrypoint-initdb.d/README
new file mode 100644
index 000000000..6a54b93da
--- /dev/null
+++ b/ci/docker-entrypoint-initdb.d/README
@@ -0,0 +1,12 @@
+To test with a MariaDB or MySQL container image:
+
+docker run -d -p 3306:3306 -e MYSQL_ALLOW_EMPTY_PASSWORD=1 \
+ --name=mysqld -v ./ci/docker-entrypoint-initdb.d:/docker-entrypoint-initdb.d:z \
+ mysql:8.0.26 --local-infile=1
+
+cp ci/docker.json pymysql/tests/databases.json
+
+pytest
+
+
+Note: Some authentication tests that don't match the image version will fail.
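
Once the container is up and `ci/docker.json` is copied into place, a quick hedged way to confirm connectivity before running pytest; the credentials match the first entry of `ci/docker.json`:

```python
# Connectivity check against the container started above.
# root / empty password / database test1 match ci/docker.json.
import pymysql

con = pymysql.connect(host="127.0.0.1", port=3306, user="root", password="", database="test1")
with con:
    with con.cursor() as cur:
        cur.execute("SELECT VERSION()")
        print(cur.fetchone())
```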
diff --git a/ci/docker-entrypoint-initdb.d/init.sql b/ci/docker-entrypoint-initdb.d/init.sql
new file mode 100644
index 000000000..b741d41c5
--- /dev/null
+++ b/ci/docker-entrypoint-initdb.d/init.sql
@@ -0,0 +1,7 @@
+create database test1 DEFAULT CHARACTER SET utf8mb4;
+create database test2 DEFAULT CHARACTER SET utf8mb4;
+create user test2 identified by 'some password';
+grant all on test2.* to test2;
+create user test2@localhost identified by 'some password';
+grant all on test2.* to test2@localhost;
+
diff --git a/ci/docker-entrypoint-initdb.d/mariadb.sql b/ci/docker-entrypoint-initdb.d/mariadb.sql
new file mode 100644
index 000000000..912d365a9
--- /dev/null
+++ b/ci/docker-entrypoint-initdb.d/mariadb.sql
@@ -0,0 +1,2 @@
+/*M!100122 INSTALL SONAME "auth_ed25519" */;
+/*M!100122 CREATE FUNCTION ed25519_password RETURNS STRING SONAME "auth_ed25519.so" */;
diff --git a/ci/docker-entrypoint-initdb.d/mysql.sql b/ci/docker-entrypoint-initdb.d/mysql.sql
new file mode 100644
index 000000000..a4ba0927d
--- /dev/null
+++ b/ci/docker-entrypoint-initdb.d/mysql.sql
@@ -0,0 +1,8 @@
+/*!80001 CREATE USER
+ user_sha256 IDENTIFIED WITH "sha256_password" BY "pass_sha256_01234567890123456789",
+ nopass_sha256 IDENTIFIED WITH "sha256_password",
+ user_caching_sha2 IDENTIFIED WITH "caching_sha2_password" BY "pass_caching_sha2_01234567890123456789",
+ nopass_caching_sha2 IDENTIFIED WITH "caching_sha2_password"
+ PASSWORD EXPIRE NEVER */;
+
+/*!80001 GRANT RELOAD ON *.* TO user_caching_sha2 */;
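
These accounts let the test suite exercise the sha256_password and caching_sha2_password plugins. With the `PyMySQL[rsa]` extra installed, connecting to them needs nothing special; a hedged sketch using the caching_sha2 user defined above, where host and port assume the CI container:

```python
# Requires the PyMySQL[rsa] extra for caching_sha2_password / sha256_password.
# The account comes from mysql.sql above; host/port assume the CI container.
import pymysql

con = pymysql.connect(
    host="127.0.0.1",
    port=3306,
    user="user_caching_sha2",
    password="pass_caching_sha2_01234567890123456789",
)
with con.cursor() as cur:
    cur.execute("SELECT CURRENT_USER()")
    print(cur.fetchone())
con.close()
```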
diff --git a/ci/docker.json b/ci/docker.json
new file mode 100644
index 000000000..63d19a687
--- /dev/null
+++ b/ci/docker.json
@@ -0,0 +1,5 @@
+[
+ {"host": "127.0.0.1", "port": 3306, "user": "root", "password": "", "database": "test1", "use_unicode": true, "local_infile": true},
+ {"host": "127.0.0.1", "port": 3306, "user": "test2", "password": "some password", "database": "test2" },
+ {"host": "localhost", "port": 3306, "user": "test2", "password": "some password", "database": "test2", "unix_socket": "/run/mysqld/mysqld.sock"}
+]
diff --git a/ci/test_mysql.py b/ci/test_mysql.py
new file mode 100644
index 000000000..b97978a27
--- /dev/null
+++ b/ci/test_mysql.py
@@ -0,0 +1,47 @@
+# This is an example test settings file for use with the Django test suite.
+#
+# The 'sqlite3' backend requires only the ENGINE setting (an in-
+# memory database will be used). All other backends will require a
+# NAME and potentially authentication information. See the
+# following section in the docs for more information:
+#
+# https://docs.djangoproject.com/en/dev/internals/contributing/writing-code/unit-tests/
+#
+# The different databases that Django supports behave differently in certain
+# situations, so it is recommended to run the test suite against as many
+# database backends as possible. You may want to create a separate settings
+# file for each of the backends you test against.
+
+import pymysql
+
+pymysql.install_as_MySQLdb()
+
+DATABASES = {
+ "default": {
+ "ENGINE": "django.db.backends.mysql",
+ "NAME": "django_default",
+ "HOST": "127.0.0.1",
+ "USER": "scott",
+ "PASSWORD": "tiger",
+ "TEST": {"CHARSET": "utf8mb3", "COLLATION": "utf8mb3_general_ci"},
+ },
+ "other": {
+ "ENGINE": "django.db.backends.mysql",
+ "NAME": "django_other",
+ "HOST": "127.0.0.1",
+ "USER": "scott",
+ "PASSWORD": "tiger",
+ "TEST": {"CHARSET": "utf8mb3", "COLLATION": "utf8mb3_general_ci"},
+ },
+}
+
+SECRET_KEY = "django_tests_secret_key"
+
+# Use a fast hasher to speed up tests.
+PASSWORD_HASHERS = [
+ "django.contrib.auth.hashers.MD5PasswordHasher",
+]
+
+DEFAULT_AUTO_FIELD = "django.db.models.AutoField"
+
+USE_TZ = False
diff --git a/codecov.yml b/codecov.yml
new file mode 100644
index 000000000..919adf200
--- /dev/null
+++ b/codecov.yml
@@ -0,0 +1,7 @@
+# https://docs.codecov.com/docs/common-recipe-list
+coverage:
+ status:
+ project:
+ default:
+ target: auto
+ threshold: 3%
diff --git a/docs/Makefile b/docs/Makefile
index d37255520..c1240d2ba 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -74,30 +74,6 @@ json:
@echo
@echo "Build finished; now you can process the JSON files."
-htmlhelp:
- $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
- @echo
- @echo "Build finished; now you can run HTML Help Workshop with the" \
- ".hhp project file in $(BUILDDIR)/htmlhelp."
-
-qthelp:
- $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp
- @echo
- @echo "Build finished; now you can run "qcollectiongenerator" with the" \
- ".qhcp project file in $(BUILDDIR)/qthelp, like this:"
- @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/PyMySQL.qhcp"
- @echo "To view the help file:"
- @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/PyMySQL.qhc"
-
-devhelp:
- $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp
- @echo
- @echo "Build finished."
- @echo "To view the help file:"
- @echo "# mkdir -p $$HOME/.local/share/devhelp/PyMySQL"
- @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/PyMySQL"
- @echo "# devhelp"
-
epub:
$(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub
@echo
diff --git a/docs/make.bat b/docs/make.bat
deleted file mode 100644
index dcd4287c6..000000000
--- a/docs/make.bat
+++ /dev/null
@@ -1,242 +0,0 @@
-@ECHO OFF
-
-REM Command file for Sphinx documentation
-
-if "%SPHINXBUILD%" == "" (
- set SPHINXBUILD=sphinx-build
-)
-set BUILDDIR=build
-set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% source
-set I18NSPHINXOPTS=%SPHINXOPTS% source
-if NOT "%PAPER%" == "" (
- set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS%
- set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS%
-)
-
-if "%1" == "" goto help
-
-if "%1" == "help" (
- :help
- echo.Please use `make ^` where ^ is one of
- echo. html to make standalone HTML files
- echo. dirhtml to make HTML files named index.html in directories
- echo. singlehtml to make a single large HTML file
- echo. pickle to make pickle files
- echo. json to make JSON files
- echo. htmlhelp to make HTML files and a HTML help project
- echo. qthelp to make HTML files and a qthelp project
- echo. devhelp to make HTML files and a Devhelp project
- echo. epub to make an epub
- echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter
- echo. text to make text files
- echo. man to make manual pages
- echo. texinfo to make Texinfo files
- echo. gettext to make PO message catalogs
- echo. changes to make an overview over all changed/added/deprecated items
- echo. xml to make Docutils-native XML files
- echo. pseudoxml to make pseudoxml-XML files for display purposes
- echo. linkcheck to check all external links for integrity
- echo. doctest to run all doctests embedded in the documentation if enabled
- goto end
-)
-
-if "%1" == "clean" (
- for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i
- del /q /s %BUILDDIR%\*
- goto end
-)
-
-
-%SPHINXBUILD% 2> nul
-if errorlevel 9009 (
- echo.
- echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
- echo.installed, then set the SPHINXBUILD environment variable to point
- echo.to the full path of the 'sphinx-build' executable. Alternatively you
- echo.may add the Sphinx directory to PATH.
- echo.
- echo.If you don't have Sphinx installed, grab it from
- echo.http://sphinx-doc.org/
- exit /b 1
-)
-
-if "%1" == "html" (
- %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The HTML pages are in %BUILDDIR%/html.
- goto end
-)
-
-if "%1" == "dirhtml" (
- %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml.
- goto end
-)
-
-if "%1" == "singlehtml" (
- %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml.
- goto end
-)
-
-if "%1" == "pickle" (
- %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished; now you can process the pickle files.
- goto end
-)
-
-if "%1" == "json" (
- %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished; now you can process the JSON files.
- goto end
-)
-
-if "%1" == "htmlhelp" (
- %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished; now you can run HTML Help Workshop with the ^
-.hhp project file in %BUILDDIR%/htmlhelp.
- goto end
-)
-
-if "%1" == "qthelp" (
- %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished; now you can run "qcollectiongenerator" with the ^
-.qhcp project file in %BUILDDIR%/qthelp, like this:
- echo.^> qcollectiongenerator %BUILDDIR%\qthelp\PyMySQL.qhcp
- echo.To view the help file:
- echo.^> assistant -collectionFile %BUILDDIR%\qthelp\PyMySQL.ghc
- goto end
-)
-
-if "%1" == "devhelp" (
- %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished.
- goto end
-)
-
-if "%1" == "epub" (
- %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The epub file is in %BUILDDIR%/epub.
- goto end
-)
-
-if "%1" == "latex" (
- %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished; the LaTeX files are in %BUILDDIR%/latex.
- goto end
-)
-
-if "%1" == "latexpdf" (
- %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
- cd %BUILDDIR%/latex
- make all-pdf
- cd %BUILDDIR%/..
- echo.
- echo.Build finished; the PDF files are in %BUILDDIR%/latex.
- goto end
-)
-
-if "%1" == "latexpdfja" (
- %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
- cd %BUILDDIR%/latex
- make all-pdf-ja
- cd %BUILDDIR%/..
- echo.
- echo.Build finished; the PDF files are in %BUILDDIR%/latex.
- goto end
-)
-
-if "%1" == "text" (
- %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The text files are in %BUILDDIR%/text.
- goto end
-)
-
-if "%1" == "man" (
- %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The manual pages are in %BUILDDIR%/man.
- goto end
-)
-
-if "%1" == "texinfo" (
- %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo.
- goto end
-)
-
-if "%1" == "gettext" (
- %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The message catalogs are in %BUILDDIR%/locale.
- goto end
-)
-
-if "%1" == "changes" (
- %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes
- if errorlevel 1 exit /b 1
- echo.
- echo.The overview file is in %BUILDDIR%/changes.
- goto end
-)
-
-if "%1" == "linkcheck" (
- %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck
- if errorlevel 1 exit /b 1
- echo.
- echo.Link check complete; look for any errors in the above output ^
-or in %BUILDDIR%/linkcheck/output.txt.
- goto end
-)
-
-if "%1" == "doctest" (
- %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest
- if errorlevel 1 exit /b 1
- echo.
- echo.Testing of doctests in the sources finished, look at the ^
-results in %BUILDDIR%/doctest/output.txt.
- goto end
-)
-
-if "%1" == "xml" (
- %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The XML files are in %BUILDDIR%/xml.
- goto end
-)
-
-if "%1" == "pseudoxml" (
- %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml.
- goto end
-)
-
-:end
diff --git a/docs/requirements.txt b/docs/requirements.txt
new file mode 100644
index 000000000..014066235
--- /dev/null
+++ b/docs/requirements.txt
@@ -0,0 +1,2 @@
+sphinx~=7.2
+sphinx-rtd-theme~=2.0.0
diff --git a/docs/source/conf.py b/docs/source/conf.py
index bbadcbed1..158d0d12f 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-#
# PyMySQL documentation build configuration file, created by
# sphinx-quickstart on Tue May 17 12:01:11 2016.
#
@@ -18,55 +16,54 @@
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
-sys.path.insert(0, os.path.abspath('../../'))
+sys.path.insert(0, os.path.abspath("../../"))
# -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
-#needs_sphinx = '1.0'
+# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
- 'sphinx.ext.autodoc',
- 'sphinx.ext.intersphinx',
+ "sphinx.ext.autodoc",
]
# Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]
# The suffix of source filenames.
-source_suffix = '.rst'
+source_suffix = ".rst"
# The encoding of source files.
-#source_encoding = 'utf-8-sig'
+# source_encoding = 'utf-8-sig'
# The master toctree document.
-master_doc = 'index'
+master_doc = "index"
# General information about the project.
-project = u'PyMySQL'
-copyright = u'2016, Yutaka Matsubara and GitHub contributors'
+project = "PyMySQL"
+copyright = "2023, Inada Naoki and GitHub contributors"
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
-version = '0.7'
+version = "0.7"
# The full version, including alpha/beta/rc tags.
-release = '0.7.2'
+release = "0.7.2"
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
-#language = None
+# language = None
# There are two options for replacing |today|: either, you set today to some
# non-false value, then it is used:
-#today = ''
+# today = ''
# Else, today_fmt is used as the format for a strftime call.
-#today_fmt = '%B %d, %Y'
+# today_fmt = '%B %d, %Y'
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
@@ -74,154 +71,157 @@
# The reST default role (used for this markup: `text`) to use for all
# documents.
-#default_role = None
+# default_role = None
# If true, '()' will be appended to :func: etc. cross-reference text.
-#add_function_parentheses = True
+# add_function_parentheses = True
# If true, the current module name will be prepended to all description
# unit titles (such as .. function::).
-#add_module_names = True
+# add_module_names = True
# If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default.
-#show_authors = False
+# show_authors = False
# The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'sphinx'
+pygments_style = "sphinx"
# A list of ignored prefixes for module index sorting.
-#modindex_common_prefix = []
+# modindex_common_prefix = []
# If true, keep warnings as "system message" paragraphs in the built documents.
-#keep_warnings = False
+# keep_warnings = False
# -- Options for HTML output ----------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
-html_theme = 'default'
+html_theme = "sphinx_rtd_theme"
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
-#html_theme_options = {}
+# html_theme_options = {}
# Add any paths that contain custom themes here, relative to this directory.
-#html_theme_path = []
+# html_theme_path = []
# The name for this set of Sphinx documents. If None, it defaults to
# " v documentation".
-#html_title = None
+# html_title = None
# A shorter title for the navigation bar. Default is the same as html_title.
-#html_short_title = None
+# html_short_title = None
# The name of an image file (relative to this directory) to place at the top
# of the sidebar.
-#html_logo = None
+# html_logo = None
# The name of an image file (within the static path) to use as favicon of the
# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
# pixels large.
-#html_favicon = None
+# html_favicon = None
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]
# Add any extra paths that contain custom files (such as robots.txt or
# .htaccess) here, relative to this directory. These files are copied
# directly to the root of the documentation.
-#html_extra_path = []
+# html_extra_path = []
# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
# using the given strftime format.
-#html_last_updated_fmt = '%b %d, %Y'
+# html_last_updated_fmt = '%b %d, %Y'
# If true, SmartyPants will be used to convert quotes and dashes to
# typographically correct entities.
-#html_use_smartypants = True
+# html_use_smartypants = True
# Custom sidebar templates, maps document names to template names.
-#html_sidebars = {}
+# html_sidebars = {}
# Additional templates that should be rendered to pages, maps page names to
# template names.
-#html_additional_pages = {}
+# html_additional_pages = {}
# If false, no module index is generated.
-#html_domain_indices = True
+# html_domain_indices = True
# If false, no index is generated.
-#html_use_index = True
+# html_use_index = True
# If true, the index is split into individual pages for each letter.
-#html_split_index = False
+# html_split_index = False
# If true, links to the reST sources are added to the pages.
-#html_show_sourcelink = True
+# html_show_sourcelink = True
# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
-#html_show_sphinx = True
+# html_show_sphinx = True
# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
-#html_show_copyright = True
+# html_show_copyright = True
# If true, an OpenSearch description file will be output, and all pages will
# contain a tag referring to it. The value of this option must be the
# base URL from which the finished HTML is served.
-#html_use_opensearch = ''
+# html_use_opensearch = ''
# This is the file name suffix for HTML files (e.g. ".xhtml").
-#html_file_suffix = None
+# html_file_suffix = None
# Output file base name for HTML help builder.
-htmlhelp_basename = 'PyMySQLdoc'
+htmlhelp_basename = "PyMySQLdoc"
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
-# The paper size ('letterpaper' or 'a4paper').
-#'papersize': 'letterpaper',
-
-# The font size ('10pt', '11pt' or '12pt').
-#'pointsize': '10pt',
-
-# Additional stuff for the LaTeX preamble.
-#'preamble': '',
+ # The paper size ('letterpaper' or 'a4paper').
+ #'papersize': 'letterpaper',
+ # The font size ('10pt', '11pt' or '12pt').
+ #'pointsize': '10pt',
+ # Additional stuff for the LaTeX preamble.
+ #'preamble': '',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
- ('index', 'PyMySQL.tex', u'PyMySQL Documentation',
- u'Yutaka Matsubara and GitHub contributors', 'manual'),
+ (
+ "index",
+ "PyMySQL.tex",
+ "PyMySQL Documentation",
+ "Yutaka Matsubara and GitHub contributors",
+ "manual",
+ ),
]
# The name of an image file (relative to this directory) to place at the top of
# the title page.
-#latex_logo = None
+# latex_logo = None
# For "manual" documents, if this is true, then toplevel headings are parts,
# not chapters.
-#latex_use_parts = False
+# latex_use_parts = False
# If true, show page references after internal links.
-#latex_show_pagerefs = False
+# latex_show_pagerefs = False
# If true, show URL addresses after external links.
-#latex_show_urls = False
+# latex_show_urls = False
# Documents to append as an appendix to all manuals.
-#latex_appendices = []
+# latex_appendices = []
# If false, no module index is generated.
-#latex_domain_indices = True
+# latex_domain_indices = True
# -- Options for manual page output ---------------------------------------
@@ -229,12 +229,17 @@
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
- ('index', 'pymysql', u'PyMySQL Documentation',
- [u'Yutaka Matsubara and GitHub contributors'], 1)
+ (
+ "index",
+ "pymysql",
+ "PyMySQL Documentation",
+ ["Yutaka Matsubara and GitHub contributors"],
+ 1,
+ )
]
# If true, show URL addresses after external links.
-#man_show_urls = False
+# man_show_urls = False
# -- Options for Texinfo output -------------------------------------------
@@ -243,23 +248,29 @@
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
- ('index', 'PyMySQL', u'PyMySQL Documentation',
- u'Yutaka Matsubara and GitHub contributors', 'PyMySQL', 'One line description of project.',
- 'Miscellaneous'),
+ (
+ "index",
+ "PyMySQL",
+ "PyMySQL Documentation",
+ "Yutaka Matsubara and GitHub contributors",
+ "PyMySQL",
+ "One line description of project.",
+ "Miscellaneous",
+ ),
]
# Documents to append as an appendix to all manuals.
-#texinfo_appendices = []
+# texinfo_appendices = []
# If false, no module index is generated.
-#texinfo_domain_indices = True
+# texinfo_domain_indices = True
# How to display URL addresses: 'footnote', 'no', or 'inline'.
-#texinfo_show_urls = 'footnote'
+# texinfo_show_urls = 'footnote'
# If true, do not generate a @detailmenu in the "Top" node's menu.
-#texinfo_no_detailmenu = False
+# texinfo_no_detailmenu = False
# Example configuration for intersphinx: refer to the Python standard library.
-intersphinx_mapping = {'http://docs.python.org/': None}
+intersphinx_mapping = {"http://docs.python.org/": None}
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 97633f1aa..e64b64238 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -1,5 +1,5 @@
-Welcome to PyMySQL's documentation!
-===================================
+PyMySQL documentation
+=====================
.. toctree::
:maxdepth: 2
diff --git a/docs/source/user/development.rst b/docs/source/user/development.rst
index 39c40e1a7..1f8a2637f 100644
--- a/docs/source/user/development.rst
+++ b/docs/source/user/development.rst
@@ -22,17 +22,13 @@ If you would like to run the test suite, create a database for testing like this
mysql -e 'create database test_pymysql DEFAULT CHARACTER SET utf8 DEFAULT COLLATE utf8_general_ci;'
mysql -e 'create database test_pymysql2 DEFAULT CHARACTER SET utf8 DEFAULT COLLATE utf8_general_ci;'
-Then, copy the file ``.travis/database.json`` to ``pymysql/tests/databases.json``
+Then, copy the file ``ci/database.json`` to ``pymysql/tests/databases.json``
and edit the new file to match your MySQL configuration::
- $ cp .travis/database.json pymysql/tests/databases.json
+ $ cp ci/database.json pymysql/tests/databases.json
$ $EDITOR pymysql/tests/databases.json
To run all the tests, execute the script ``runtests.py``::
- $ python runtests.py
-
-A ``tox.ini`` file is also provided for conveniently running tests on multiple
-Python versions::
-
- $ tox
+ $ pip install pytest
+ $ pytest -v pymysql
diff --git a/docs/source/user/examples.rst b/docs/source/user/examples.rst
index 87af40c37..3946db9b9 100644
--- a/docs/source/user/examples.rst
+++ b/docs/source/user/examples.rst
@@ -18,7 +18,7 @@ The following examples make use of a simple table
`email` varchar(255) COLLATE utf8_bin NOT NULL,
`password` varchar(255) COLLATE utf8_bin NOT NULL,
PRIMARY KEY (`id`)
- ) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin
AUTO_INCREMENT=1 ;
@@ -30,11 +30,11 @@ The following examples make use of a simple table
connection = pymysql.connect(host='localhost',
user='user',
password='passwd',
- db='db',
+ database='db',
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor)
- try:
+ with connection:
with connection.cursor() as cursor:
# Create a new record
sql = "INSERT INTO `users` (`email`, `password`) VALUES (%s, %s)"
@@ -50,11 +50,10 @@ The following examples make use of a simple table
cursor.execute(sql, ('webmaster@python.org',))
result = cursor.fetchone()
print(result)
- finally:
- connection.close()
+
This example will print:
.. code:: python
- {'password': 'very-secret', 'id': 1}
+ {'id': 1, 'password': 'very-secret'}
diff --git a/docs/source/user/installation.rst b/docs/source/user/installation.rst
index e3bfe84d0..9313f14d3 100644
--- a/docs/source/user/installation.rst
+++ b/docs/source/user/installation.rst
@@ -6,24 +6,27 @@ Installation
The last stable release is available on PyPI and can be installed with ``pip``::
- $ pip install PyMySQL
+ $ python3 -m pip install PyMySQL
+
+To use "sha256_password" or "caching_sha2_password" for authenticate,
+you need to install additional dependency::
+
+ $ python3 -m pip install PyMySQL[rsa]
Requirements
-------------
* Python -- one of the following:
- - CPython_ >= 2.6 or >= 3.3
- - PyPy_ >= 4.0
- - IronPython_ 2.7
+ - CPython_ >= 3.7
+ - Latest PyPy_ 3
* MySQL Server -- one of the following:
- - MySQL_ >= 4.1 (tested with only 5.5~)
- - MariaDB_ >= 5.1
+ - MySQL_ >= 5.7
+ - MariaDB_ >= 10.3
.. _CPython: http://www.python.org/
.. _PyPy: http://pypy.org/
-.. _IronPython: http://ironpython.net/
.. _MySQL: http://www.mysql.com/
.. _MariaDB: https://mariadb.org/
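For context only (not part of the patch), here is a minimal sketch of how the ``rsa`` extra is used in practice once installed. Host, user, password, and database names are placeholders; PyMySQL negotiates ``caching_sha2_password`` automatically, and the ``cryptography`` dependency is only exercised when the server requests RSA-encrypted password exchange over a non-TLS connection.

```python
# Hypothetical example -- credentials and names are placeholders.
import pymysql

connection = pymysql.connect(
    host="127.0.0.1",
    user="appuser",
    password="app-password",
    database="appdb",
)
with connection:
    with connection.cursor() as cursor:
        cursor.execute("SELECT VERSION()")
        print(cursor.fetchone())
```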
diff --git a/example.py b/example.py
index 68582138d..c12f103b5 100644
--- a/example.py
+++ b/example.py
@@ -1,16 +1,13 @@
#!/usr/bin/env python
-from __future__ import print_function
-
import pymysql
-conn = pymysql.connect(host='localhost', port=3306, user='root', passwd='', db='mysql')
+conn = pymysql.connect(host="localhost", port=3306, user="root", passwd="", db="mysql")
cur = conn.cursor()
cur.execute("SELECT Host,User FROM user")
print(cur.description)
-
print()
for row in cur:
diff --git a/pymysql/__init__.py b/pymysql/__init__.py
index b79b4b83e..bbf9023ef 100644
--- a/pymysql/__init__.py
+++ b/pymysql/__init__.py
@@ -21,32 +21,65 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
+
import sys
-from ._compat import PY2
from .constants import FIELD_TYPE
-from .converters import escape_dict, escape_sequence, escape_string
from .err import (
- Warning, Error, InterfaceError, DataError,
- DatabaseError, OperationalError, IntegrityError, InternalError,
- NotSupportedError, ProgrammingError, MySQLError)
+ Warning,
+ Error,
+ InterfaceError,
+ DataError,
+ DatabaseError,
+ OperationalError,
+ IntegrityError,
+ InternalError,
+ NotSupportedError,
+ ProgrammingError,
+ MySQLError,
+)
from .times import (
- Date, Time, Timestamp,
- DateFromTicks, TimeFromTicks, TimestampFromTicks)
+ Date,
+ Time,
+ Timestamp,
+ DateFromTicks,
+ TimeFromTicks,
+ TimestampFromTicks,
+)
+
+# PyMySQL version.
+# Used by setuptools and connection_attrs
+VERSION = (1, 1, 1, "final", 1)
+VERSION_STRING = "1.1.1"
+
+### for mysqlclient compatibility
+### Django checks mysqlclient version.
+version_info = (1, 4, 6, "final", 1)
+__version__ = "1.4.6"
+
+
+def get_client_info(): # for MySQLdb compatibility
+ return __version__
+
+
+def install_as_MySQLdb():
+ """
+ After this function is called, any application that imports MySQLdb
+ will unwittingly actually use pymysql.
+ """
+ sys.modules["MySQLdb"] = sys.modules["pymysql"]
+
+# end of mysqlclient compatibility code
-VERSION = (0, 9, 2, None)
-if VERSION[3] is not None:
- VERSION_STRING = "%d.%d.%d_%s" % VERSION
-else:
- VERSION_STRING = "%d.%d.%d" % VERSION[:3]
threadsafety = 1
apilevel = "2.0"
paramstyle = "pyformat"
+from . import connections # noqa: E402
-class DBAPISet(frozenset):
+class DBAPISet(frozenset):
def __ne__(self, other):
if isinstance(other, set):
return frozenset.__ne__(self, other)
@@ -63,79 +96,88 @@ def __hash__(self):
return frozenset.__hash__(self)
-STRING = DBAPISet([FIELD_TYPE.ENUM, FIELD_TYPE.STRING,
- FIELD_TYPE.VAR_STRING])
-BINARY = DBAPISet([FIELD_TYPE.BLOB, FIELD_TYPE.LONG_BLOB,
- FIELD_TYPE.MEDIUM_BLOB, FIELD_TYPE.TINY_BLOB])
-NUMBER = DBAPISet([FIELD_TYPE.DECIMAL, FIELD_TYPE.DOUBLE, FIELD_TYPE.FLOAT,
- FIELD_TYPE.INT24, FIELD_TYPE.LONG, FIELD_TYPE.LONGLONG,
- FIELD_TYPE.TINY, FIELD_TYPE.YEAR])
-DATE = DBAPISet([FIELD_TYPE.DATE, FIELD_TYPE.NEWDATE])
-TIME = DBAPISet([FIELD_TYPE.TIME])
+STRING = DBAPISet([FIELD_TYPE.ENUM, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING])
+BINARY = DBAPISet(
+ [
+ FIELD_TYPE.BLOB,
+ FIELD_TYPE.LONG_BLOB,
+ FIELD_TYPE.MEDIUM_BLOB,
+ FIELD_TYPE.TINY_BLOB,
+ ]
+)
+NUMBER = DBAPISet(
+ [
+ FIELD_TYPE.DECIMAL,
+ FIELD_TYPE.DOUBLE,
+ FIELD_TYPE.FLOAT,
+ FIELD_TYPE.INT24,
+ FIELD_TYPE.LONG,
+ FIELD_TYPE.LONGLONG,
+ FIELD_TYPE.TINY,
+ FIELD_TYPE.YEAR,
+ ]
+)
+DATE = DBAPISet([FIELD_TYPE.DATE, FIELD_TYPE.NEWDATE])
+TIME = DBAPISet([FIELD_TYPE.TIME])
TIMESTAMP = DBAPISet([FIELD_TYPE.TIMESTAMP, FIELD_TYPE.DATETIME])
-DATETIME = TIMESTAMP
-ROWID = DBAPISet()
+DATETIME = TIMESTAMP
+ROWID = DBAPISet()
def Binary(x):
"""Return x as a binary type."""
- if PY2:
- return bytearray(x)
- else:
- return bytes(x)
-
-
-def Connect(*args, **kwargs):
- """
- Connect to the database; see connections.Connection.__init__() for
- more information.
- """
- from .connections import Connection
- return Connection(*args, **kwargs)
-
-from . import connections as _orig_conn
-if _orig_conn.Connection.__init__.__doc__ is not None:
- Connect.__doc__ = _orig_conn.Connection.__init__.__doc__
-del _orig_conn
-
-
-def get_client_info(): # for MySQLdb compatibility
- version = VERSION
- if VERSION[3] is None:
- version = VERSION[:3]
- return '.'.join(map(str, version))
-
-connect = Connection = Connect
-
-# we include a doctored version_info here for MySQLdb compatibility
-version_info = (1, 3, 12, "final", 0)
+ return bytes(x)
-NULL = "NULL"
-
-__version__ = get_client_info()
def thread_safe():
return True # match MySQLdb.thread_safe()
-def install_as_MySQLdb():
- """
- After this function is called, any application that imports MySQLdb or
- _mysql will unwittingly actually use pymysql.
- """
- sys.modules["MySQLdb"] = sys.modules["_mysql"] = sys.modules["pymysql"]
+
+Connect = connect = Connection = connections.Connection
+NULL = "NULL"
__all__ = [
- 'BINARY', 'Binary', 'Connect', 'Connection', 'DATE', 'Date',
- 'Time', 'Timestamp', 'DateFromTicks', 'TimeFromTicks', 'TimestampFromTicks',
- 'DataError', 'DatabaseError', 'Error', 'FIELD_TYPE', 'IntegrityError',
- 'InterfaceError', 'InternalError', 'MySQLError', 'NULL', 'NUMBER',
- 'NotSupportedError', 'DBAPISet', 'OperationalError', 'ProgrammingError',
- 'ROWID', 'STRING', 'TIME', 'TIMESTAMP', 'Warning', 'apilevel', 'connect',
- 'connections', 'constants', 'converters', 'cursors',
- 'escape_dict', 'escape_sequence', 'escape_string', 'get_client_info',
- 'paramstyle', 'threadsafety', 'version_info',
-
+ "BINARY",
+ "Binary",
+ "Connect",
+ "Connection",
+ "DATE",
+ "Date",
+ "Time",
+ "Timestamp",
+ "DateFromTicks",
+ "TimeFromTicks",
+ "TimestampFromTicks",
+ "DataError",
+ "DatabaseError",
+ "Error",
+ "FIELD_TYPE",
+ "IntegrityError",
+ "InterfaceError",
+ "InternalError",
+ "MySQLError",
+ "NULL",
+ "NUMBER",
+ "NotSupportedError",
+ "DBAPISet",
+ "OperationalError",
+ "ProgrammingError",
+ "ROWID",
+ "STRING",
+ "TIME",
+ "TIMESTAMP",
+ "Warning",
+ "apilevel",
+ "connect",
+ "connections",
+ "constants",
+ "converters",
+ "cursors",
+ "get_client_info",
+ "paramstyle",
+ "threadsafety",
+ "version_info",
"install_as_MySQLdb",
- "NULL", "__version__",
+ "__version__",
]
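The rewritten compatibility shim above keeps ``install_as_MySQLdb()`` but no longer aliases ``_mysql``. A small illustrative sketch of the intended use:

```python
# Illustration of the MySQLdb compatibility shim defined above.
import pymysql

pymysql.install_as_MySQLdb()

import MySQLdb  # resolves to the pymysql module after the call above

assert MySQLdb is pymysql
print(MySQLdb.get_client_info())  # the mysqlclient-compatible version string
```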
diff --git a/pymysql/_auth.py b/pymysql/_auth.py
index bbb742d3a..8ce744fb5 100644
--- a/pymysql/_auth.py
+++ b/pymysql/_auth.py
@@ -1,22 +1,26 @@
"""
Implements auth methods
"""
-from ._compat import text_type, PY2
-from .constants import CLIENT
+
from .err import OperationalError
-from cryptography.hazmat.backends import default_backend
-from cryptography.hazmat.primitives import serialization, hashes
-from cryptography.hazmat.primitives.asymmetric import padding
+
+try:
+ from cryptography.hazmat.backends import default_backend
+ from cryptography.hazmat.primitives import serialization, hashes
+ from cryptography.hazmat.primitives.asymmetric import padding
+
+ _have_cryptography = True
+except ImportError:
+ _have_cryptography = False
from functools import partial
import hashlib
-import struct
DEBUG = False
SCRAMBLE_LENGTH = 20
-sha1_new = partial(hashlib.new, 'sha1')
+sha1_new = partial(hashlib.new, "sha1")
# mysql_native_password
@@ -26,7 +30,7 @@
def scramble_native_password(password, message):
"""Scramble used for mysql_native_password"""
if not password:
- return b''
+ return b""
stage1 = sha1_new(password).digest()
stage2 = sha1_new(stage1).digest()
@@ -39,8 +43,6 @@ def scramble_native_password(password, message):
def _my_crypt(message1, message2):
result = bytearray(message1)
- if PY2:
- message2 = bytearray(message2)
for i in range(len(result)):
result[i] ^= message2[i]
@@ -48,60 +50,67 @@ def _my_crypt(message1, message2):
return bytes(result)
-# old_passwords support ported from libmysql/password.c
-# https://dev.mysql.com/doc/internals/en/old-password-authentication.html
+# MariaDB's client_ed25519-plugin
+# https://mariadb.com/kb/en/library/connection/#client_ed25519-plugin
-SCRAMBLE_LENGTH_323 = 8
+_nacl_bindings = False
-class RandStruct_323(object):
+def _init_nacl():
+ global _nacl_bindings
+ try:
+ from nacl import bindings
- def __init__(self, seed1, seed2):
- self.max_value = 0x3FFFFFFF
- self.seed1 = seed1 % self.max_value
- self.seed2 = seed2 % self.max_value
+ _nacl_bindings = bindings
+ except ImportError:
+ raise RuntimeError(
+ "'pynacl' package is required for ed25519_password auth method"
+ )
- def my_rnd(self):
- self.seed1 = (self.seed1 * 3 + self.seed2) % self.max_value
- self.seed2 = (self.seed1 + self.seed2 + 33) % self.max_value
- return float(self.seed1) / float(self.max_value)
+def _scalar_clamp(s32):
+ ba = bytearray(s32)
+ ba0 = bytes(bytearray([ba[0] & 248]))
+ ba31 = bytes(bytearray([(ba[31] & 127) | 64]))
+ return ba0 + bytes(s32[1:31]) + ba31
-def scramble_old_password(password, message):
- """Scramble for old_password"""
- hash_pass = _hash_password_323(password)
- hash_message = _hash_password_323(message[:SCRAMBLE_LENGTH_323])
- hash_pass_n = struct.unpack(">LL", hash_pass)
- hash_message_n = struct.unpack(">LL", hash_message)
- rand_st = RandStruct_323(
- hash_pass_n[0] ^ hash_message_n[0], hash_pass_n[1] ^ hash_message_n[1]
- )
- outbuf = io.BytesIO()
- for _ in range(min(SCRAMBLE_LENGTH_323, len(message))):
- outbuf.write(int2byte(int(rand_st.my_rnd() * 31) + 64))
- extra = int2byte(int(rand_st.my_rnd() * 31))
- out = outbuf.getvalue()
- outbuf = io.BytesIO()
- for c in out:
- outbuf.write(int2byte(byte2int(c) ^ byte2int(extra)))
- return outbuf.getvalue()
+def ed25519_password(password, scramble):
+ """Sign a random scramble with elliptic curve Ed25519.
+
+ Secret and public key are derived from password.
+ """
+ # variable names based on rfc8032 section-5.1.6
+ #
+ if not _nacl_bindings:
+ _init_nacl()
+
+ # h = SHA512(password)
+ h = hashlib.sha512(password).digest()
+
+ # s = prune(first_half(h))
+ s = _scalar_clamp(h[:32])
+ # r = SHA512(second_half(h) || M)
+ r = hashlib.sha512(h[32:] + scramble).digest()
-def _hash_password_323(password):
- nr = 1345345333
- add = 7
- nr2 = 0x12345671
+ # R = encoded point [r]B
+ r = _nacl_bindings.crypto_core_ed25519_scalar_reduce(r)
+ R = _nacl_bindings.crypto_scalarmult_ed25519_base_noclamp(r)
- # x in py3 is numbers, p27 is chars
- for c in [byte2int(x) for x in password if x not in (' ', '\t', 32, 9)]:
- nr ^= (((nr & 63) + add) * c) + (nr << 8) & 0xFFFFFFFF
- nr2 = (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF
- add = (add + c) & 0xFFFFFFFF
+ # A = encoded point [s]B
+ A = _nacl_bindings.crypto_scalarmult_ed25519_base_noclamp(s)
- r1 = nr & ((1 << 31) - 1) # kill sign bits
- r2 = nr2 & ((1 << 31) - 1)
- return struct.pack(">LL", r1, r2)
+ # k = SHA512(R || A || M)
+ k = hashlib.sha512(R + A + scramble).digest()
+
+ # S = (k * s + r) mod L
+ k = _nacl_bindings.crypto_core_ed25519_scalar_reduce(k)
+ ks = _nacl_bindings.crypto_core_ed25519_scalar_mul(k, s)
+ S = _nacl_bindings.crypto_core_ed25519_scalar_add(ks, r)
+
+ # signature = R || S
+ return R + S
# sha256_password
@@ -115,8 +124,11 @@ def _roundtrip(conn, send_data):
def _xor_password(password, salt):
+ # Trailing NUL character will be added in Auth Switch Request.
+ # See https://github.com/mysql/mysql-server/blob/7d10c82196c8e45554f27c00681474a9fb86d137/sql/auth/sha2_password.cc#L939-L945
+ salt = salt[:SCRAMBLE_LENGTH]
password_bytes = bytearray(password)
- salt = bytearray(salt) # for PY2 compat.
+ # salt = bytearray(salt) # for PY2 compat.
salt_len = len(salt)
for i in range(len(password_bytes)):
password_bytes[i] ^= salt[i % salt_len]
@@ -128,7 +140,12 @@ def sha2_rsa_encrypt(password, salt, public_key):
Used for sha256_password and caching_sha2_password.
"""
- message = _xor_password(password + b'\0', salt)
+ if not _have_cryptography:
+ raise RuntimeError(
+ "'cryptography' package is required for sha256_password or"
+ + " caching_sha2_password auth methods"
+ )
+ message = _xor_password(password + b"\0", salt)
rsa_key = serialization.load_pem_public_key(public_key, default_backend())
return rsa_key.encrypt(
message,
@@ -144,7 +161,7 @@ def sha256_password_auth(conn, pkt):
if conn._secure:
if DEBUG:
print("sha256: Sending plain password")
- data = conn.password + b'\0'
+ data = conn.password + b"\0"
return _roundtrip(conn, data)
if pkt.is_auth_switch_request():
@@ -153,12 +170,12 @@ def sha256_password_auth(conn, pkt):
# Request server public key
if DEBUG:
print("sha256: Requesting server public key")
- pkt = _roundtrip(conn, b'\1')
+ pkt = _roundtrip(conn, b"\1")
if pkt.is_extra_auth_data():
conn.server_public_key = pkt._data[1:]
if DEBUG:
- print("Received public key:\n", conn.server_public_key.decode('ascii'))
+ print("Received public key:\n", conn.server_public_key.decode("ascii"))
if conn.password:
if not conn.server_public_key:
@@ -166,7 +183,7 @@ def sha256_password_auth(conn, pkt):
data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key)
else:
- data = b''
+ data = b""
return _roundtrip(conn, data)
@@ -178,15 +195,13 @@ def scramble_caching_sha2(password, nonce):
XOR(SHA256(password), SHA256(SHA256(SHA256(password)), nonce))
"""
if not password:
- return b''
+ return b""
p1 = hashlib.sha256(password).digest()
p2 = hashlib.sha256(p1).digest()
p3 = hashlib.sha256(p2 + nonce).digest()
res = bytearray(p1)
- if PY2:
- p3 = bytearray(p3)
for i in range(len(p3)):
res[i] ^= p3[i]
@@ -196,7 +211,7 @@ def scramble_caching_sha2(password, nonce):
def caching_sha2_password_auth(conn, pkt):
# No password fast path
if not conn.password:
- return _roundtrip(conn, b'')
+ return _roundtrip(conn, b"")
if pkt.is_auth_switch_request():
# Try from fast auth
@@ -228,7 +243,7 @@ def caching_sha2_password_auth(conn, pkt):
return pkt
if n != 4:
- raise OperationalError("caching sha2: Unknwon result for fast auth: %s" % n)
+ raise OperationalError("caching sha2: Unknown result for fast auth: %s" % n)
if DEBUG:
print("caching sha2: Trying full auth...")
@@ -236,10 +251,10 @@ def caching_sha2_password_auth(conn, pkt):
if conn._secure:
if DEBUG:
print("caching sha2: Sending plain password via secure connection")
- return _roundtrip(conn, conn.password + b'\0')
+ return _roundtrip(conn, conn.password + b"\0")
if not conn.server_public_key:
- pkt = _roundtrip(conn, b'\x02') # Request public key
+ pkt = _roundtrip(conn, b"\x02") # Request public key
if not pkt.is_extra_auth_data():
raise OperationalError(
"caching sha2: Unknown packet for public key: %s" % pkt._data[:1]
@@ -247,7 +262,7 @@ def caching_sha2_password_auth(conn, pkt):
conn.server_public_key = pkt._data[1:]
if DEBUG:
- print(conn.server_public_key.decode('ascii'))
+ print(conn.server_public_key.decode("ascii"))
data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key)
pkt = _roundtrip(conn, data)
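The docstring above spells out the ``caching_sha2_password`` scramble as ``XOR(SHA256(password), SHA256(SHA256(SHA256(password)), nonce))``. A self-contained sketch using only ``hashlib``, mirroring ``scramble_caching_sha2()`` for illustration:

```python
# Standalone sketch of the caching_sha2_password scramble; illustrative only,
# not a replacement for pymysql._auth.scramble_caching_sha2().
import hashlib


def caching_sha2_scramble(password: bytes, nonce: bytes) -> bytes:
    if not password:
        return b""
    p1 = hashlib.sha256(password).digest()       # SHA256(password)
    p2 = hashlib.sha256(p1).digest()             # SHA256(SHA256(password))
    p3 = hashlib.sha256(p2 + nonce).digest()     # SHA256(p2 || nonce)
    return bytes(a ^ b for a, b in zip(p1, p3))  # byte-wise XOR


# The server sends a 20-byte nonce; this one is made up.
print(caching_sha2_scramble(b"secret", b"\x01" * 20).hex())
```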
diff --git a/pymysql/_compat.py b/pymysql/_compat.py
deleted file mode 100644
index 252789ec4..000000000
--- a/pymysql/_compat.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import sys
-
-PY2 = sys.version_info[0] == 2
-PYPY = hasattr(sys, 'pypy_translation_info')
-JYTHON = sys.platform.startswith('java')
-IRONPYTHON = sys.platform == 'cli'
-CPYTHON = not PYPY and not JYTHON and not IRONPYTHON
-
-if PY2:
- import __builtin__
- range_type = xrange
- text_type = unicode
- long_type = long
- str_type = basestring
- unichr = __builtin__.unichr
-else:
- range_type = range
- text_type = str
- long_type = int
- str_type = str
- unichr = chr
diff --git a/pymysql/_socketio.py b/pymysql/_socketio.py
deleted file mode 100644
index 6a11d42e4..000000000
--- a/pymysql/_socketio.py
+++ /dev/null
@@ -1,134 +0,0 @@
-"""
-SocketIO imported from socket module in Python 3.
-
-Copyright (c) 2001-2013 Python Software Foundation; All Rights Reserved.
-"""
-
-from socket import *
-import io
-import errno
-
-__all__ = ['SocketIO']
-
-EINTR = errno.EINTR
-_blocking_errnos = (errno.EAGAIN, errno.EWOULDBLOCK)
-
-class SocketIO(io.RawIOBase):
-
- """Raw I/O implementation for stream sockets.
-
- This class supports the makefile() method on sockets. It provides
- the raw I/O interface on top of a socket object.
- """
-
- # One might wonder why not let FileIO do the job instead. There are two
- # main reasons why FileIO is not adapted:
- # - it wouldn't work under Windows (where you can't used read() and
- # write() on a socket handle)
- # - it wouldn't work with socket timeouts (FileIO would ignore the
- # timeout and consider the socket non-blocking)
-
- # XXX More docs
-
- def __init__(self, sock, mode):
- if mode not in ("r", "w", "rw", "rb", "wb", "rwb"):
- raise ValueError("invalid mode: %r" % mode)
- io.RawIOBase.__init__(self)
- self._sock = sock
- if "b" not in mode:
- mode += "b"
- self._mode = mode
- self._reading = "r" in mode
- self._writing = "w" in mode
- self._timeout_occurred = False
-
- def readinto(self, b):
- """Read up to len(b) bytes into the writable buffer *b* and return
- the number of bytes read. If the socket is non-blocking and no bytes
- are available, None is returned.
-
- If *b* is non-empty, a 0 return value indicates that the connection
- was shutdown at the other end.
- """
- self._checkClosed()
- self._checkReadable()
- if self._timeout_occurred:
- raise IOError("cannot read from timed out object")
- while True:
- try:
- return self._sock.recv_into(b)
- except timeout:
- self._timeout_occurred = True
- raise
- except error as e:
- n = e.args[0]
- if n == EINTR:
- continue
- if n in _blocking_errnos:
- return None
- raise
-
- def write(self, b):
- """Write the given bytes or bytearray object *b* to the socket
- and return the number of bytes written. This can be less than
- len(b) if not all data could be written. If the socket is
- non-blocking and no bytes could be written None is returned.
- """
- self._checkClosed()
- self._checkWritable()
- try:
- return self._sock.send(b)
- except error as e:
- # XXX what about EINTR?
- if e.args[0] in _blocking_errnos:
- return None
- raise
-
- def readable(self):
- """True if the SocketIO is open for reading.
- """
- if self.closed:
- raise ValueError("I/O operation on closed socket.")
- return self._reading
-
- def writable(self):
- """True if the SocketIO is open for writing.
- """
- if self.closed:
- raise ValueError("I/O operation on closed socket.")
- return self._writing
-
- def seekable(self):
- """True if the SocketIO is open for seeking.
- """
- if self.closed:
- raise ValueError("I/O operation on closed socket.")
- return super().seekable()
-
- def fileno(self):
- """Return the file descriptor of the underlying socket.
- """
- self._checkClosed()
- return self._sock.fileno()
-
- @property
- def name(self):
- if not self.closed:
- return self.fileno()
- else:
- return -1
-
- @property
- def mode(self):
- return self._mode
-
- def close(self):
- """Close the SocketIO object. This doesn't close the underlying
- socket, except if all references to it have disappeared.
- """
- if self.closed:
- return
- io.RawIOBase.close(self)
- self._sock._decref_socketios()
- self._sock = None
-
diff --git a/pymysql/charset.py b/pymysql/charset.py
index 968376cfa..b1c1ca8b8 100644
--- a/pymysql/charset.py
+++ b/pymysql/charset.py
@@ -1,25 +1,29 @@
-MBLENGTH = {
- 8:1,
- 33:3,
- 88:2,
- 91:2
- }
+# Internal use only. Do not use directly.
+MBLENGTH = {8: 1, 33: 3, 88: 2, 91: 2}
-class Charset(object):
- def __init__(self, id, name, collation, is_default):
+
+class Charset:
+ def __init__(self, id, name, collation, is_default=False):
self.id, self.name, self.collation = id, name, collation
- self.is_default = is_default == 'Yes'
+ self.is_default = is_default
def __repr__(self):
- return "Charset(id=%s, name=%r, collation=%r)" % (
- self.id, self.name, self.collation)
+ return (
+ f"Charset(id={self.id}, name={self.name!r}, collation={self.collation!r})"
+ )
@property
def encoding(self):
name = self.name
- if name == 'utf8mb4':
- return 'utf8'
+ if name in ("utf8mb4", "utf8mb3"):
+ return "utf8"
+ if name == "latin1":
+ return "cp1252"
+ if name == "koi8r":
+ return "koi8_r"
+ if name == "koi8u":
+ return "koi8_u"
return name
@property
@@ -30,241 +34,183 @@ def is_binary(self):
class Charsets:
def __init__(self):
self._by_id = {}
+ self._by_name = {}
def add(self, c):
self._by_id[c.id] = c
+ if c.is_default:
+ self._by_name[c.name] = c
def by_id(self, id):
return self._by_id[id]
def by_name(self, name):
- name = name.lower()
- for c in self._by_id.values():
- if c.name == name and c.is_default:
- return c
+ if name == "utf8":
+ name = "utf8mb4"
+ return self._by_name.get(name.lower())
+
_charsets = Charsets()
+charset_by_name = _charsets.by_name
+charset_by_id = _charsets.by_id
+
"""
+TODO: update this script.
+
Generated with:
mysql -N -s -e "select id, character_set_name, collation_name, is_default
from information_schema.collations order by id;" | python -c "import sys
for l in sys.stdin.readlines():
- id, name, collation, is_default = l.split(chr(9))
- print '_charsets.add(Charset(%s, \'%s\', \'%s\', \'%s\'))' \
- % (id, name, collation, is_default.strip())
-"
-
+ id, name, collation, is_default = l.split(chr(9))
+ if is_default.strip() == "Yes":
+ print('_charsets.add(Charset(%s, \'%s\', \'%s\', True))' \
+ % (id, name, collation))
+ else:
+ print('_charsets.add(Charset(%s, \'%s\', \'%s\'))' \
+            % (id, name, collation))
"""
-_charsets.add(Charset(1, 'big5', 'big5_chinese_ci', 'Yes'))
-_charsets.add(Charset(2, 'latin2', 'latin2_czech_cs', ''))
-_charsets.add(Charset(3, 'dec8', 'dec8_swedish_ci', 'Yes'))
-_charsets.add(Charset(4, 'cp850', 'cp850_general_ci', 'Yes'))
-_charsets.add(Charset(5, 'latin1', 'latin1_german1_ci', ''))
-_charsets.add(Charset(6, 'hp8', 'hp8_english_ci', 'Yes'))
-_charsets.add(Charset(7, 'koi8r', 'koi8r_general_ci', 'Yes'))
-_charsets.add(Charset(8, 'latin1', 'latin1_swedish_ci', 'Yes'))
-_charsets.add(Charset(9, 'latin2', 'latin2_general_ci', 'Yes'))
-_charsets.add(Charset(10, 'swe7', 'swe7_swedish_ci', 'Yes'))
-_charsets.add(Charset(11, 'ascii', 'ascii_general_ci', 'Yes'))
-_charsets.add(Charset(12, 'ujis', 'ujis_japanese_ci', 'Yes'))
-_charsets.add(Charset(13, 'sjis', 'sjis_japanese_ci', 'Yes'))
-_charsets.add(Charset(14, 'cp1251', 'cp1251_bulgarian_ci', ''))
-_charsets.add(Charset(15, 'latin1', 'latin1_danish_ci', ''))
-_charsets.add(Charset(16, 'hebrew', 'hebrew_general_ci', 'Yes'))
-_charsets.add(Charset(18, 'tis620', 'tis620_thai_ci', 'Yes'))
-_charsets.add(Charset(19, 'euckr', 'euckr_korean_ci', 'Yes'))
-_charsets.add(Charset(20, 'latin7', 'latin7_estonian_cs', ''))
-_charsets.add(Charset(21, 'latin2', 'latin2_hungarian_ci', ''))
-_charsets.add(Charset(22, 'koi8u', 'koi8u_general_ci', 'Yes'))
-_charsets.add(Charset(23, 'cp1251', 'cp1251_ukrainian_ci', ''))
-_charsets.add(Charset(24, 'gb2312', 'gb2312_chinese_ci', 'Yes'))
-_charsets.add(Charset(25, 'greek', 'greek_general_ci', 'Yes'))
-_charsets.add(Charset(26, 'cp1250', 'cp1250_general_ci', 'Yes'))
-_charsets.add(Charset(27, 'latin2', 'latin2_croatian_ci', ''))
-_charsets.add(Charset(28, 'gbk', 'gbk_chinese_ci', 'Yes'))
-_charsets.add(Charset(29, 'cp1257', 'cp1257_lithuanian_ci', ''))
-_charsets.add(Charset(30, 'latin5', 'latin5_turkish_ci', 'Yes'))
-_charsets.add(Charset(31, 'latin1', 'latin1_german2_ci', ''))
-_charsets.add(Charset(32, 'armscii8', 'armscii8_general_ci', 'Yes'))
-_charsets.add(Charset(33, 'utf8', 'utf8_general_ci', 'Yes'))
-_charsets.add(Charset(34, 'cp1250', 'cp1250_czech_cs', ''))
-_charsets.add(Charset(35, 'ucs2', 'ucs2_general_ci', 'Yes'))
-_charsets.add(Charset(36, 'cp866', 'cp866_general_ci', 'Yes'))
-_charsets.add(Charset(37, 'keybcs2', 'keybcs2_general_ci', 'Yes'))
-_charsets.add(Charset(38, 'macce', 'macce_general_ci', 'Yes'))
-_charsets.add(Charset(39, 'macroman', 'macroman_general_ci', 'Yes'))
-_charsets.add(Charset(40, 'cp852', 'cp852_general_ci', 'Yes'))
-_charsets.add(Charset(41, 'latin7', 'latin7_general_ci', 'Yes'))
-_charsets.add(Charset(42, 'latin7', 'latin7_general_cs', ''))
-_charsets.add(Charset(43, 'macce', 'macce_bin', ''))
-_charsets.add(Charset(44, 'cp1250', 'cp1250_croatian_ci', ''))
-_charsets.add(Charset(45, 'utf8mb4', 'utf8mb4_general_ci', 'Yes'))
-_charsets.add(Charset(46, 'utf8mb4', 'utf8mb4_bin', ''))
-_charsets.add(Charset(47, 'latin1', 'latin1_bin', ''))
-_charsets.add(Charset(48, 'latin1', 'latin1_general_ci', ''))
-_charsets.add(Charset(49, 'latin1', 'latin1_general_cs', ''))
-_charsets.add(Charset(50, 'cp1251', 'cp1251_bin', ''))
-_charsets.add(Charset(51, 'cp1251', 'cp1251_general_ci', 'Yes'))
-_charsets.add(Charset(52, 'cp1251', 'cp1251_general_cs', ''))
-_charsets.add(Charset(53, 'macroman', 'macroman_bin', ''))
-_charsets.add(Charset(54, 'utf16', 'utf16_general_ci', 'Yes'))
-_charsets.add(Charset(55, 'utf16', 'utf16_bin', ''))
-_charsets.add(Charset(57, 'cp1256', 'cp1256_general_ci', 'Yes'))
-_charsets.add(Charset(58, 'cp1257', 'cp1257_bin', ''))
-_charsets.add(Charset(59, 'cp1257', 'cp1257_general_ci', 'Yes'))
-_charsets.add(Charset(60, 'utf32', 'utf32_general_ci', 'Yes'))
-_charsets.add(Charset(61, 'utf32', 'utf32_bin', ''))
-_charsets.add(Charset(63, 'binary', 'binary', 'Yes'))
-_charsets.add(Charset(64, 'armscii8', 'armscii8_bin', ''))
-_charsets.add(Charset(65, 'ascii', 'ascii_bin', ''))
-_charsets.add(Charset(66, 'cp1250', 'cp1250_bin', ''))
-_charsets.add(Charset(67, 'cp1256', 'cp1256_bin', ''))
-_charsets.add(Charset(68, 'cp866', 'cp866_bin', ''))
-_charsets.add(Charset(69, 'dec8', 'dec8_bin', ''))
-_charsets.add(Charset(70, 'greek', 'greek_bin', ''))
-_charsets.add(Charset(71, 'hebrew', 'hebrew_bin', ''))
-_charsets.add(Charset(72, 'hp8', 'hp8_bin', ''))
-_charsets.add(Charset(73, 'keybcs2', 'keybcs2_bin', ''))
-_charsets.add(Charset(74, 'koi8r', 'koi8r_bin', ''))
-_charsets.add(Charset(75, 'koi8u', 'koi8u_bin', ''))
-_charsets.add(Charset(77, 'latin2', 'latin2_bin', ''))
-_charsets.add(Charset(78, 'latin5', 'latin5_bin', ''))
-_charsets.add(Charset(79, 'latin7', 'latin7_bin', ''))
-_charsets.add(Charset(80, 'cp850', 'cp850_bin', ''))
-_charsets.add(Charset(81, 'cp852', 'cp852_bin', ''))
-_charsets.add(Charset(82, 'swe7', 'swe7_bin', ''))
-_charsets.add(Charset(83, 'utf8', 'utf8_bin', ''))
-_charsets.add(Charset(84, 'big5', 'big5_bin', ''))
-_charsets.add(Charset(85, 'euckr', 'euckr_bin', ''))
-_charsets.add(Charset(86, 'gb2312', 'gb2312_bin', ''))
-_charsets.add(Charset(87, 'gbk', 'gbk_bin', ''))
-_charsets.add(Charset(88, 'sjis', 'sjis_bin', ''))
-_charsets.add(Charset(89, 'tis620', 'tis620_bin', ''))
-_charsets.add(Charset(90, 'ucs2', 'ucs2_bin', ''))
-_charsets.add(Charset(91, 'ujis', 'ujis_bin', ''))
-_charsets.add(Charset(92, 'geostd8', 'geostd8_general_ci', 'Yes'))
-_charsets.add(Charset(93, 'geostd8', 'geostd8_bin', ''))
-_charsets.add(Charset(94, 'latin1', 'latin1_spanish_ci', ''))
-_charsets.add(Charset(95, 'cp932', 'cp932_japanese_ci', 'Yes'))
-_charsets.add(Charset(96, 'cp932', 'cp932_bin', ''))
-_charsets.add(Charset(97, 'eucjpms', 'eucjpms_japanese_ci', 'Yes'))
-_charsets.add(Charset(98, 'eucjpms', 'eucjpms_bin', ''))
-_charsets.add(Charset(99, 'cp1250', 'cp1250_polish_ci', ''))
-_charsets.add(Charset(101, 'utf16', 'utf16_unicode_ci', ''))
-_charsets.add(Charset(102, 'utf16', 'utf16_icelandic_ci', ''))
-_charsets.add(Charset(103, 'utf16', 'utf16_latvian_ci', ''))
-_charsets.add(Charset(104, 'utf16', 'utf16_romanian_ci', ''))
-_charsets.add(Charset(105, 'utf16', 'utf16_slovenian_ci', ''))
-_charsets.add(Charset(106, 'utf16', 'utf16_polish_ci', ''))
-_charsets.add(Charset(107, 'utf16', 'utf16_estonian_ci', ''))
-_charsets.add(Charset(108, 'utf16', 'utf16_spanish_ci', ''))
-_charsets.add(Charset(109, 'utf16', 'utf16_swedish_ci', ''))
-_charsets.add(Charset(110, 'utf16', 'utf16_turkish_ci', ''))
-_charsets.add(Charset(111, 'utf16', 'utf16_czech_ci', ''))
-_charsets.add(Charset(112, 'utf16', 'utf16_danish_ci', ''))
-_charsets.add(Charset(113, 'utf16', 'utf16_lithuanian_ci', ''))
-_charsets.add(Charset(114, 'utf16', 'utf16_slovak_ci', ''))
-_charsets.add(Charset(115, 'utf16', 'utf16_spanish2_ci', ''))
-_charsets.add(Charset(116, 'utf16', 'utf16_roman_ci', ''))
-_charsets.add(Charset(117, 'utf16', 'utf16_persian_ci', ''))
-_charsets.add(Charset(118, 'utf16', 'utf16_esperanto_ci', ''))
-_charsets.add(Charset(119, 'utf16', 'utf16_hungarian_ci', ''))
-_charsets.add(Charset(120, 'utf16', 'utf16_sinhala_ci', ''))
-_charsets.add(Charset(128, 'ucs2', 'ucs2_unicode_ci', ''))
-_charsets.add(Charset(129, 'ucs2', 'ucs2_icelandic_ci', ''))
-_charsets.add(Charset(130, 'ucs2', 'ucs2_latvian_ci', ''))
-_charsets.add(Charset(131, 'ucs2', 'ucs2_romanian_ci', ''))
-_charsets.add(Charset(132, 'ucs2', 'ucs2_slovenian_ci', ''))
-_charsets.add(Charset(133, 'ucs2', 'ucs2_polish_ci', ''))
-_charsets.add(Charset(134, 'ucs2', 'ucs2_estonian_ci', ''))
-_charsets.add(Charset(135, 'ucs2', 'ucs2_spanish_ci', ''))
-_charsets.add(Charset(136, 'ucs2', 'ucs2_swedish_ci', ''))
-_charsets.add(Charset(137, 'ucs2', 'ucs2_turkish_ci', ''))
-_charsets.add(Charset(138, 'ucs2', 'ucs2_czech_ci', ''))
-_charsets.add(Charset(139, 'ucs2', 'ucs2_danish_ci', ''))
-_charsets.add(Charset(140, 'ucs2', 'ucs2_lithuanian_ci', ''))
-_charsets.add(Charset(141, 'ucs2', 'ucs2_slovak_ci', ''))
-_charsets.add(Charset(142, 'ucs2', 'ucs2_spanish2_ci', ''))
-_charsets.add(Charset(143, 'ucs2', 'ucs2_roman_ci', ''))
-_charsets.add(Charset(144, 'ucs2', 'ucs2_persian_ci', ''))
-_charsets.add(Charset(145, 'ucs2', 'ucs2_esperanto_ci', ''))
-_charsets.add(Charset(146, 'ucs2', 'ucs2_hungarian_ci', ''))
-_charsets.add(Charset(147, 'ucs2', 'ucs2_sinhala_ci', ''))
-_charsets.add(Charset(159, 'ucs2', 'ucs2_general_mysql500_ci', ''))
-_charsets.add(Charset(160, 'utf32', 'utf32_unicode_ci', ''))
-_charsets.add(Charset(161, 'utf32', 'utf32_icelandic_ci', ''))
-_charsets.add(Charset(162, 'utf32', 'utf32_latvian_ci', ''))
-_charsets.add(Charset(163, 'utf32', 'utf32_romanian_ci', ''))
-_charsets.add(Charset(164, 'utf32', 'utf32_slovenian_ci', ''))
-_charsets.add(Charset(165, 'utf32', 'utf32_polish_ci', ''))
-_charsets.add(Charset(166, 'utf32', 'utf32_estonian_ci', ''))
-_charsets.add(Charset(167, 'utf32', 'utf32_spanish_ci', ''))
-_charsets.add(Charset(168, 'utf32', 'utf32_swedish_ci', ''))
-_charsets.add(Charset(169, 'utf32', 'utf32_turkish_ci', ''))
-_charsets.add(Charset(170, 'utf32', 'utf32_czech_ci', ''))
-_charsets.add(Charset(171, 'utf32', 'utf32_danish_ci', ''))
-_charsets.add(Charset(172, 'utf32', 'utf32_lithuanian_ci', ''))
-_charsets.add(Charset(173, 'utf32', 'utf32_slovak_ci', ''))
-_charsets.add(Charset(174, 'utf32', 'utf32_spanish2_ci', ''))
-_charsets.add(Charset(175, 'utf32', 'utf32_roman_ci', ''))
-_charsets.add(Charset(176, 'utf32', 'utf32_persian_ci', ''))
-_charsets.add(Charset(177, 'utf32', 'utf32_esperanto_ci', ''))
-_charsets.add(Charset(178, 'utf32', 'utf32_hungarian_ci', ''))
-_charsets.add(Charset(179, 'utf32', 'utf32_sinhala_ci', ''))
-_charsets.add(Charset(192, 'utf8', 'utf8_unicode_ci', ''))
-_charsets.add(Charset(193, 'utf8', 'utf8_icelandic_ci', ''))
-_charsets.add(Charset(194, 'utf8', 'utf8_latvian_ci', ''))
-_charsets.add(Charset(195, 'utf8', 'utf8_romanian_ci', ''))
-_charsets.add(Charset(196, 'utf8', 'utf8_slovenian_ci', ''))
-_charsets.add(Charset(197, 'utf8', 'utf8_polish_ci', ''))
-_charsets.add(Charset(198, 'utf8', 'utf8_estonian_ci', ''))
-_charsets.add(Charset(199, 'utf8', 'utf8_spanish_ci', ''))
-_charsets.add(Charset(200, 'utf8', 'utf8_swedish_ci', ''))
-_charsets.add(Charset(201, 'utf8', 'utf8_turkish_ci', ''))
-_charsets.add(Charset(202, 'utf8', 'utf8_czech_ci', ''))
-_charsets.add(Charset(203, 'utf8', 'utf8_danish_ci', ''))
-_charsets.add(Charset(204, 'utf8', 'utf8_lithuanian_ci', ''))
-_charsets.add(Charset(205, 'utf8', 'utf8_slovak_ci', ''))
-_charsets.add(Charset(206, 'utf8', 'utf8_spanish2_ci', ''))
-_charsets.add(Charset(207, 'utf8', 'utf8_roman_ci', ''))
-_charsets.add(Charset(208, 'utf8', 'utf8_persian_ci', ''))
-_charsets.add(Charset(209, 'utf8', 'utf8_esperanto_ci', ''))
-_charsets.add(Charset(210, 'utf8', 'utf8_hungarian_ci', ''))
-_charsets.add(Charset(211, 'utf8', 'utf8_sinhala_ci', ''))
-_charsets.add(Charset(223, 'utf8', 'utf8_general_mysql500_ci', ''))
-_charsets.add(Charset(224, 'utf8mb4', 'utf8mb4_unicode_ci', ''))
-_charsets.add(Charset(225, 'utf8mb4', 'utf8mb4_icelandic_ci', ''))
-_charsets.add(Charset(226, 'utf8mb4', 'utf8mb4_latvian_ci', ''))
-_charsets.add(Charset(227, 'utf8mb4', 'utf8mb4_romanian_ci', ''))
-_charsets.add(Charset(228, 'utf8mb4', 'utf8mb4_slovenian_ci', ''))
-_charsets.add(Charset(229, 'utf8mb4', 'utf8mb4_polish_ci', ''))
-_charsets.add(Charset(230, 'utf8mb4', 'utf8mb4_estonian_ci', ''))
-_charsets.add(Charset(231, 'utf8mb4', 'utf8mb4_spanish_ci', ''))
-_charsets.add(Charset(232, 'utf8mb4', 'utf8mb4_swedish_ci', ''))
-_charsets.add(Charset(233, 'utf8mb4', 'utf8mb4_turkish_ci', ''))
-_charsets.add(Charset(234, 'utf8mb4', 'utf8mb4_czech_ci', ''))
-_charsets.add(Charset(235, 'utf8mb4', 'utf8mb4_danish_ci', ''))
-_charsets.add(Charset(236, 'utf8mb4', 'utf8mb4_lithuanian_ci', ''))
-_charsets.add(Charset(237, 'utf8mb4', 'utf8mb4_slovak_ci', ''))
-_charsets.add(Charset(238, 'utf8mb4', 'utf8mb4_spanish2_ci', ''))
-_charsets.add(Charset(239, 'utf8mb4', 'utf8mb4_roman_ci', ''))
-_charsets.add(Charset(240, 'utf8mb4', 'utf8mb4_persian_ci', ''))
-_charsets.add(Charset(241, 'utf8mb4', 'utf8mb4_esperanto_ci', ''))
-_charsets.add(Charset(242, 'utf8mb4', 'utf8mb4_hungarian_ci', ''))
-_charsets.add(Charset(243, 'utf8mb4', 'utf8mb4_sinhala_ci', ''))
-_charsets.add(Charset(244, 'utf8mb4', 'utf8mb4_german2_ci', ''))
-_charsets.add(Charset(245, 'utf8mb4', 'utf8mb4_croatian_ci', ''))
-_charsets.add(Charset(246, 'utf8mb4', 'utf8mb4_unicode_520_ci', ''))
-_charsets.add(Charset(247, 'utf8mb4', 'utf8mb4_vietnamese_ci', ''))
-
-
-charset_by_name = _charsets.by_name
-charset_by_id = _charsets.by_id
-
-def charset_to_encoding(name):
- """Convert MySQL's charset name to Python's codec name"""
- if name == 'utf8mb4':
- return 'utf8'
- return name
+_charsets.add(Charset(1, "big5", "big5_chinese_ci", True))
+_charsets.add(Charset(2, "latin2", "latin2_czech_cs"))
+_charsets.add(Charset(3, "dec8", "dec8_swedish_ci", True))
+_charsets.add(Charset(4, "cp850", "cp850_general_ci", True))
+_charsets.add(Charset(5, "latin1", "latin1_german1_ci"))
+_charsets.add(Charset(6, "hp8", "hp8_english_ci", True))
+_charsets.add(Charset(7, "koi8r", "koi8r_general_ci", True))
+_charsets.add(Charset(8, "latin1", "latin1_swedish_ci", True))
+_charsets.add(Charset(9, "latin2", "latin2_general_ci", True))
+_charsets.add(Charset(10, "swe7", "swe7_swedish_ci", True))
+_charsets.add(Charset(11, "ascii", "ascii_general_ci", True))
+_charsets.add(Charset(12, "ujis", "ujis_japanese_ci", True))
+_charsets.add(Charset(13, "sjis", "sjis_japanese_ci", True))
+_charsets.add(Charset(14, "cp1251", "cp1251_bulgarian_ci"))
+_charsets.add(Charset(15, "latin1", "latin1_danish_ci"))
+_charsets.add(Charset(16, "hebrew", "hebrew_general_ci", True))
+_charsets.add(Charset(18, "tis620", "tis620_thai_ci", True))
+_charsets.add(Charset(19, "euckr", "euckr_korean_ci", True))
+_charsets.add(Charset(20, "latin7", "latin7_estonian_cs"))
+_charsets.add(Charset(21, "latin2", "latin2_hungarian_ci"))
+_charsets.add(Charset(22, "koi8u", "koi8u_general_ci", True))
+_charsets.add(Charset(23, "cp1251", "cp1251_ukrainian_ci"))
+_charsets.add(Charset(24, "gb2312", "gb2312_chinese_ci", True))
+_charsets.add(Charset(25, "greek", "greek_general_ci", True))
+_charsets.add(Charset(26, "cp1250", "cp1250_general_ci", True))
+_charsets.add(Charset(27, "latin2", "latin2_croatian_ci"))
+_charsets.add(Charset(28, "gbk", "gbk_chinese_ci", True))
+_charsets.add(Charset(29, "cp1257", "cp1257_lithuanian_ci"))
+_charsets.add(Charset(30, "latin5", "latin5_turkish_ci", True))
+_charsets.add(Charset(31, "latin1", "latin1_german2_ci"))
+_charsets.add(Charset(32, "armscii8", "armscii8_general_ci", True))
+_charsets.add(Charset(33, "utf8mb3", "utf8mb3_general_ci", True))
+_charsets.add(Charset(34, "cp1250", "cp1250_czech_cs"))
+_charsets.add(Charset(36, "cp866", "cp866_general_ci", True))
+_charsets.add(Charset(37, "keybcs2", "keybcs2_general_ci", True))
+_charsets.add(Charset(38, "macce", "macce_general_ci", True))
+_charsets.add(Charset(39, "macroman", "macroman_general_ci", True))
+_charsets.add(Charset(40, "cp852", "cp852_general_ci", True))
+_charsets.add(Charset(41, "latin7", "latin7_general_ci", True))
+_charsets.add(Charset(42, "latin7", "latin7_general_cs"))
+_charsets.add(Charset(43, "macce", "macce_bin"))
+_charsets.add(Charset(44, "cp1250", "cp1250_croatian_ci"))
+_charsets.add(Charset(45, "utf8mb4", "utf8mb4_general_ci", True))
+_charsets.add(Charset(46, "utf8mb4", "utf8mb4_bin"))
+_charsets.add(Charset(47, "latin1", "latin1_bin"))
+_charsets.add(Charset(48, "latin1", "latin1_general_ci"))
+_charsets.add(Charset(49, "latin1", "latin1_general_cs"))
+_charsets.add(Charset(50, "cp1251", "cp1251_bin"))
+_charsets.add(Charset(51, "cp1251", "cp1251_general_ci", True))
+_charsets.add(Charset(52, "cp1251", "cp1251_general_cs"))
+_charsets.add(Charset(53, "macroman", "macroman_bin"))
+_charsets.add(Charset(57, "cp1256", "cp1256_general_ci", True))
+_charsets.add(Charset(58, "cp1257", "cp1257_bin"))
+_charsets.add(Charset(59, "cp1257", "cp1257_general_ci", True))
+_charsets.add(Charset(63, "binary", "binary", True))
+_charsets.add(Charset(64, "armscii8", "armscii8_bin"))
+_charsets.add(Charset(65, "ascii", "ascii_bin"))
+_charsets.add(Charset(66, "cp1250", "cp1250_bin"))
+_charsets.add(Charset(67, "cp1256", "cp1256_bin"))
+_charsets.add(Charset(68, "cp866", "cp866_bin"))
+_charsets.add(Charset(69, "dec8", "dec8_bin"))
+_charsets.add(Charset(70, "greek", "greek_bin"))
+_charsets.add(Charset(71, "hebrew", "hebrew_bin"))
+_charsets.add(Charset(72, "hp8", "hp8_bin"))
+_charsets.add(Charset(73, "keybcs2", "keybcs2_bin"))
+_charsets.add(Charset(74, "koi8r", "koi8r_bin"))
+_charsets.add(Charset(75, "koi8u", "koi8u_bin"))
+_charsets.add(Charset(76, "utf8mb3", "utf8mb3_tolower_ci"))
+_charsets.add(Charset(77, "latin2", "latin2_bin"))
+_charsets.add(Charset(78, "latin5", "latin5_bin"))
+_charsets.add(Charset(79, "latin7", "latin7_bin"))
+_charsets.add(Charset(80, "cp850", "cp850_bin"))
+_charsets.add(Charset(81, "cp852", "cp852_bin"))
+_charsets.add(Charset(82, "swe7", "swe7_bin"))
+_charsets.add(Charset(83, "utf8mb3", "utf8mb3_bin"))
+_charsets.add(Charset(84, "big5", "big5_bin"))
+_charsets.add(Charset(85, "euckr", "euckr_bin"))
+_charsets.add(Charset(86, "gb2312", "gb2312_bin"))
+_charsets.add(Charset(87, "gbk", "gbk_bin"))
+_charsets.add(Charset(88, "sjis", "sjis_bin"))
+_charsets.add(Charset(89, "tis620", "tis620_bin"))
+_charsets.add(Charset(91, "ujis", "ujis_bin"))
+_charsets.add(Charset(92, "geostd8", "geostd8_general_ci", True))
+_charsets.add(Charset(93, "geostd8", "geostd8_bin"))
+_charsets.add(Charset(94, "latin1", "latin1_spanish_ci"))
+_charsets.add(Charset(95, "cp932", "cp932_japanese_ci", True))
+_charsets.add(Charset(96, "cp932", "cp932_bin"))
+_charsets.add(Charset(97, "eucjpms", "eucjpms_japanese_ci", True))
+_charsets.add(Charset(98, "eucjpms", "eucjpms_bin"))
+_charsets.add(Charset(99, "cp1250", "cp1250_polish_ci"))
+_charsets.add(Charset(192, "utf8mb3", "utf8mb3_unicode_ci"))
+_charsets.add(Charset(193, "utf8mb3", "utf8mb3_icelandic_ci"))
+_charsets.add(Charset(194, "utf8mb3", "utf8mb3_latvian_ci"))
+_charsets.add(Charset(195, "utf8mb3", "utf8mb3_romanian_ci"))
+_charsets.add(Charset(196, "utf8mb3", "utf8mb3_slovenian_ci"))
+_charsets.add(Charset(197, "utf8mb3", "utf8mb3_polish_ci"))
+_charsets.add(Charset(198, "utf8mb3", "utf8mb3_estonian_ci"))
+_charsets.add(Charset(199, "utf8mb3", "utf8mb3_spanish_ci"))
+_charsets.add(Charset(200, "utf8mb3", "utf8mb3_swedish_ci"))
+_charsets.add(Charset(201, "utf8mb3", "utf8mb3_turkish_ci"))
+_charsets.add(Charset(202, "utf8mb3", "utf8mb3_czech_ci"))
+_charsets.add(Charset(203, "utf8mb3", "utf8mb3_danish_ci"))
+_charsets.add(Charset(204, "utf8mb3", "utf8mb3_lithuanian_ci"))
+_charsets.add(Charset(205, "utf8mb3", "utf8mb3_slovak_ci"))
+_charsets.add(Charset(206, "utf8mb3", "utf8mb3_spanish2_ci"))
+_charsets.add(Charset(207, "utf8mb3", "utf8mb3_roman_ci"))
+_charsets.add(Charset(208, "utf8mb3", "utf8mb3_persian_ci"))
+_charsets.add(Charset(209, "utf8mb3", "utf8mb3_esperanto_ci"))
+_charsets.add(Charset(210, "utf8mb3", "utf8mb3_hungarian_ci"))
+_charsets.add(Charset(211, "utf8mb3", "utf8mb3_sinhala_ci"))
+_charsets.add(Charset(212, "utf8mb3", "utf8mb3_german2_ci"))
+_charsets.add(Charset(213, "utf8mb3", "utf8mb3_croatian_ci"))
+_charsets.add(Charset(214, "utf8mb3", "utf8mb3_unicode_520_ci"))
+_charsets.add(Charset(215, "utf8mb3", "utf8mb3_vietnamese_ci"))
+_charsets.add(Charset(223, "utf8mb3", "utf8mb3_general_mysql500_ci"))
+_charsets.add(Charset(224, "utf8mb4", "utf8mb4_unicode_ci"))
+_charsets.add(Charset(225, "utf8mb4", "utf8mb4_icelandic_ci"))
+_charsets.add(Charset(226, "utf8mb4", "utf8mb4_latvian_ci"))
+_charsets.add(Charset(227, "utf8mb4", "utf8mb4_romanian_ci"))
+_charsets.add(Charset(228, "utf8mb4", "utf8mb4_slovenian_ci"))
+_charsets.add(Charset(229, "utf8mb4", "utf8mb4_polish_ci"))
+_charsets.add(Charset(230, "utf8mb4", "utf8mb4_estonian_ci"))
+_charsets.add(Charset(231, "utf8mb4", "utf8mb4_spanish_ci"))
+_charsets.add(Charset(232, "utf8mb4", "utf8mb4_swedish_ci"))
+_charsets.add(Charset(233, "utf8mb4", "utf8mb4_turkish_ci"))
+_charsets.add(Charset(234, "utf8mb4", "utf8mb4_czech_ci"))
+_charsets.add(Charset(235, "utf8mb4", "utf8mb4_danish_ci"))
+_charsets.add(Charset(236, "utf8mb4", "utf8mb4_lithuanian_ci"))
+_charsets.add(Charset(237, "utf8mb4", "utf8mb4_slovak_ci"))
+_charsets.add(Charset(238, "utf8mb4", "utf8mb4_spanish2_ci"))
+_charsets.add(Charset(239, "utf8mb4", "utf8mb4_roman_ci"))
+_charsets.add(Charset(240, "utf8mb4", "utf8mb4_persian_ci"))
+_charsets.add(Charset(241, "utf8mb4", "utf8mb4_esperanto_ci"))
+_charsets.add(Charset(242, "utf8mb4", "utf8mb4_hungarian_ci"))
+_charsets.add(Charset(243, "utf8mb4", "utf8mb4_sinhala_ci"))
+_charsets.add(Charset(244, "utf8mb4", "utf8mb4_german2_ci"))
+_charsets.add(Charset(245, "utf8mb4", "utf8mb4_croatian_ci"))
+_charsets.add(Charset(246, "utf8mb4", "utf8mb4_unicode_520_ci"))
+_charsets.add(Charset(247, "utf8mb4", "utf8mb4_vietnamese_ci"))
+_charsets.add(Charset(248, "gb18030", "gb18030_chinese_ci", True))
+_charsets.add(Charset(249, "gb18030", "gb18030_bin"))
+_charsets.add(Charset(250, "gb18030", "gb18030_unicode_520_ci"))
+_charsets.add(Charset(255, "utf8mb4", "utf8mb4_0900_ai_ci"))
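The charset module is now flagged as internal, but the registry behaviour follows directly from the table above; a short illustration of the two lookup helpers (expected output follows from the entries listed):

```python
# Illustration of the lookup helpers defined above. by_name() only knows
# default collations, and "utf8" is treated as an alias of "utf8mb4".
from pymysql.charset import charset_by_id, charset_by_name

cs = charset_by_name("utf8")
print(cs.name, cs.collation, cs.encoding)  # utf8mb4 utf8mb4_general_ci utf8

latin1 = charset_by_id(8)
print(latin1.collation, latin1.encoding)   # latin1_swedish_ci cp1252
```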
diff --git a/pymysql/connections.py b/pymysql/connections.py
index 1e580d21d..3a04ddd68 100644
--- a/pymysql/connections.py
+++ b/pymysql/connections.py
@@ -1,12 +1,8 @@
# Python implementation of the MySQL client-server protocol
# http://dev.mysql.com/doc/internals/en/client-server-protocol.html
# Error codes:
-# http://dev.mysql.com/doc/refman/5.5/en/error-messages-client.html
-from __future__ import print_function
-from ._compat import PY2, range_type, text_type, str_type, JYTHON, IRONPYTHON
-
+# https://dev.mysql.com/doc/refman/5.5/en/error-handling.html
import errno
-import io
import os
import socket
import struct
@@ -17,19 +13,23 @@
from . import _auth
from .charset import charset_by_name, charset_by_id
-from .constants import CLIENT, COMMAND, CR, FIELD_TYPE, SERVER_STATUS
+from .constants import CLIENT, COMMAND, CR, ER, FIELD_TYPE, SERVER_STATUS
from . import converters
from .cursors import Cursor
from .optionfile import Parser
from .protocol import (
- dump_packet, MysqlPacket, FieldDescriptorPacket, OKPacketWrapper,
- EOFPacketWrapper, LoadLocalPacketWrapper
+ dump_packet,
+ MysqlPacket,
+ FieldDescriptorPacket,
+ OKPacketWrapper,
+ EOFPacketWrapper,
+ LoadLocalPacketWrapper,
)
-from .util import byte2int, int2byte
from . import err, VERSION_STRING
try:
import ssl
+
SSL_ENABLED = True
except ImportError:
ssl = None
@@ -37,6 +37,7 @@
try:
import getpass
+
DEFAULT_USER = getpass.getuser()
del getpass
except (ImportError, KeyError):
@@ -45,36 +46,6 @@
DEBUG = False
-_py_version = sys.version_info[:2]
-
-if PY2:
- pass
-elif _py_version < (3, 6):
- # See http://bugs.python.org/issue24870
- _surrogateescape_table = [chr(i) if i < 0x80 else chr(i + 0xdc00) for i in range(256)]
-
- def _fast_surrogateescape(s):
- return s.decode('latin1').translate(_surrogateescape_table)
-else:
- def _fast_surrogateescape(s):
- return s.decode('ascii', 'surrogateescape')
-
-# socket.makefile() in Python 2 is not usable because very inefficient and
-# bad behavior about timeout.
-# XXX: ._socketio doesn't work under IronPython.
-if PY2 and not IRONPYTHON:
- # read method of file-like returned by sock.makefile() is very slow.
- # So we copy io-based one from Python 3.
- from ._socketio import SocketIO
-
- def _makefile(sock, mode):
- return io.BufferedReader(SocketIO(sock, mode))
-else:
- # socket.makefile in Python 3 is nice.
- def _makefile(sock, mode):
- return sock.makefile(mode)
-
-
TEXT_TYPES = {
FIELD_TYPE.BIT,
FIELD_TYPE.BLOB,
@@ -88,32 +59,36 @@ def _makefile(sock, mode):
}
-DEFAULT_CHARSET = 'utf8mb4' # TODO: change to utf8mb4
+DEFAULT_CHARSET = "utf8mb4"
-MAX_PACKET_LEN = 2**24-1
+MAX_PACKET_LEN = 2**24 - 1
-def pack_int24(n):
-    return struct.pack('<I', n)[:3]
     See `Connection <https://www.python.org/dev/peps/pep-0249/#connection-objects>`_ in the
specification.
"""
_sock = None
- _auth_plugin_name = ''
+ _auth_plugin_name = ""
_closed = False
_secure = False
- def __init__(self, host=None, user=None, password="",
- database=None, port=0, unix_socket=None,
- charset='', sql_mode=None,
- read_default_file=None, conv=None, use_unicode=None,
- client_flag=0, cursorclass=Cursor, init_command=None,
- connect_timeout=10, ssl=None, read_default_group=None,
- compress=None, named_pipe=None,
- autocommit=False, db=None, passwd=None, local_infile=False,
- max_allowed_packet=16*1024*1024, defer_connect=False,
- auth_plugin_map=None, read_timeout=None, write_timeout=None,
- bind_address=None, binary_prefix=False, program_name=None,
- server_public_key=None):
- if use_unicode is None and sys.version_info[0] > 2:
- use_unicode = True
-
+ def __init__(
+ self,
+ *,
+        user=None,  # The first four arguments are based on the DB-API 2.0 recommendation.
+ password="",
+ host=None,
+ database=None,
+ unix_socket=None,
+ port=0,
+ charset="",
+ collation=None,
+ sql_mode=None,
+ read_default_file=None,
+ conv=None,
+ use_unicode=True,
+ client_flag=0,
+ cursorclass=Cursor,
+ init_command=None,
+ connect_timeout=10,
+ read_default_group=None,
+ autocommit=False,
+ local_infile=False,
+ max_allowed_packet=16 * 1024 * 1024,
+ defer_connect=False,
+ auth_plugin_map=None,
+ read_timeout=None,
+ write_timeout=None,
+ bind_address=None,
+ binary_prefix=False,
+ program_name=None,
+ server_public_key=None,
+ ssl=None,
+ ssl_ca=None,
+ ssl_cert=None,
+ ssl_disabled=None,
+ ssl_key=None,
+ ssl_key_password=None,
+ ssl_verify_cert=None,
+ ssl_verify_identity=None,
+ compress=None, # not supported
+ named_pipe=None, # not supported
+ passwd=None, # deprecated
+ db=None, # deprecated
+ ):
if db is not None and database is None:
+            # We will raise a warning in 2022 or later.
+ # See https://github.com/PyMySQL/PyMySQL/issues/939
+ # warnings.warn("'db' is deprecated, use 'database'", DeprecationWarning, 3)
database = db
if passwd is not None and not password:
+            # We will raise a warning in 2022 or later.
+ # See https://github.com/PyMySQL/PyMySQL/issues/939
+ # warnings.warn(
+ # "'passwd' is deprecated, use 'password'", DeprecationWarning, 3
+ # )
password = passwd
if compress or named_pipe:
- raise NotImplementedError("compress and named_pipe arguments are not supported")
+ raise NotImplementedError(
+ "compress and named_pipe arguments are not supported"
+ )
self._local_infile = bool(local_infile)
if self._local_infile:
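Because the new signature begins with a bare ``*``, every argument to ``Connection`` (and therefore ``pymysql.connect``) is keyword-only; ``db`` and ``passwd`` still work as deprecated aliases. A quick sketch with placeholder credentials:

```python
# Placeholder credentials; keyword arguments only.
import pymysql

conn = pymysql.connect(host="localhost", user="root", password="", database="mysql")

# The old positional style now raises TypeError:
# pymysql.connect("localhost", "root", "", "mysql")
```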
@@ -240,25 +263,42 @@ def _config(key, arg):
if not ssl:
ssl = {}
if isinstance(ssl, dict):
- for key in ["ca", "capath", "cert", "key", "cipher"]:
+ for key in ["ca", "capath", "cert", "key", "password", "cipher"]:
value = _config("ssl-" + key, ssl.get(key))
if value:
ssl[key] = value
self.ssl = False
- if ssl:
- if not SSL_ENABLED:
- raise NotImplementedError("ssl module not found")
- self.ssl = True
- client_flag |= CLIENT.SSL
- self.ctx = self._create_ssl_ctx(ssl)
+ if not ssl_disabled:
+ if ssl_ca or ssl_cert or ssl_key or ssl_verify_cert or ssl_verify_identity:
+ ssl = {
+ "ca": ssl_ca,
+ "check_hostname": bool(ssl_verify_identity),
+ "verify_mode": ssl_verify_cert
+ if ssl_verify_cert is not None
+ else False,
+ }
+ if ssl_cert is not None:
+ ssl["cert"] = ssl_cert
+ if ssl_key is not None:
+ ssl["key"] = ssl_key
+ if ssl_key_password is not None:
+ ssl["password"] = ssl_key_password
+ if ssl:
+ if not SSL_ENABLED:
+ raise NotImplementedError("ssl module not found")
+ self.ssl = True
+ client_flag |= CLIENT.SSL
+ self.ctx = self._create_ssl_ctx(ssl)
self.host = host or "localhost"
self.port = port or 3306
+ if type(self.port) is not int:
+ raise ValueError("port should be of type int")
self.user = user or DEFAULT_USER
self.password = password or b""
- if isinstance(self.password, text_type):
- self.password = self.password.encode('latin1')
+ if isinstance(self.password, str):
+ self.password = self.password.encode("latin1")
self.db = database
self.unix_socket = unix_socket
self.bind_address = bind_address
@@ -266,20 +306,15 @@ def _config(key, arg):
raise ValueError("connect_timeout should be >0 and <=31536000")
self.connect_timeout = connect_timeout or None
if read_timeout is not None and read_timeout <= 0:
- raise ValueError("read_timeout should be >= 0")
+ raise ValueError("read_timeout should be > 0")
self._read_timeout = read_timeout
if write_timeout is not None and write_timeout <= 0:
- raise ValueError("write_timeout should be >= 0")
+ raise ValueError("write_timeout should be > 0")
self._write_timeout = write_timeout
- if charset:
- self.charset = charset
- self.use_unicode = True
- else:
- self.charset = DEFAULT_CHARSET
- self.use_unicode = False
- if use_unicode is not None:
- self.use_unicode = use_unicode
+ self.charset = charset or DEFAULT_CHARSET
+ self.collation = collation
+ self.use_unicode = use_unicode
self.encoding = charset_by_name(self.charset).encoding
@@ -295,15 +330,15 @@ def _config(key, arg):
self._affected_rows = 0
self.host_info = "Not connected"
- #: specified autocommit mode. None means use server default.
+ # specified autocommit mode. None means use server default.
self.autocommit_mode = autocommit
if conv is None:
conv = converters.conversions
# Need for MySQLdb compatibility.
- self.encoders = dict([(k, v) for (k, v) in conv.items() if type(k) is not int])
- self.decoders = dict([(k, v) for (k, v) in conv.items() if type(k) is int])
+ self.encoders = {k: v for (k, v) in conv.items() if type(k) is not int}
+ self.decoders = {k: v for (k, v) in conv.items() if type(k) is int}
self.sql_mode = sql_mode
self.init_command = init_command
self.max_allowed_packet = max_allowed_packet
@@ -312,33 +347,56 @@ def _config(key, arg):
self.server_public_key = server_public_key
self._connect_attrs = {
- '_client_name': 'pymysql',
- '_pid': str(os.getpid()),
- '_client_version': VERSION_STRING,
+ "_client_name": "pymysql",
+ "_client_version": VERSION_STRING,
+ "_pid": str(os.getpid()),
}
+
if program_name:
self._connect_attrs["program_name"] = program_name
- elif sys.argv:
- self._connect_attrs["program_name"] = sys.argv[0]
if defer_connect:
self._sock = None
else:
self.connect()
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *exc_info):
+ del exc_info
+ self.close()
+
def _create_ssl_ctx(self, sslp):
if isinstance(sslp, ssl.SSLContext):
return sslp
- ca = sslp.get('ca')
- capath = sslp.get('capath')
+ ca = sslp.get("ca")
+ capath = sslp.get("capath")
hasnoca = ca is None and capath is None
ctx = ssl.create_default_context(cafile=ca, capath=capath)
- ctx.check_hostname = not hasnoca and sslp.get('check_hostname', True)
- ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
- if 'cert' in sslp:
- ctx.load_cert_chain(sslp['cert'], keyfile=sslp.get('key'))
- if 'cipher' in sslp:
- ctx.set_ciphers(sslp['cipher'])
+ ctx.check_hostname = not hasnoca and sslp.get("check_hostname", True)
+ verify_mode_value = sslp.get("verify_mode")
+ if verify_mode_value is None:
+ ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
+ elif isinstance(verify_mode_value, bool):
+ ctx.verify_mode = ssl.CERT_REQUIRED if verify_mode_value else ssl.CERT_NONE
+ else:
+ if isinstance(verify_mode_value, str):
+ verify_mode_value = verify_mode_value.lower()
+ if verify_mode_value in ("none", "0", "false", "no"):
+ ctx.verify_mode = ssl.CERT_NONE
+ elif verify_mode_value == "optional":
+ ctx.verify_mode = ssl.CERT_OPTIONAL
+ elif verify_mode_value in ("required", "1", "true", "yes"):
+ ctx.verify_mode = ssl.CERT_REQUIRED
+ else:
+ ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
+ if "cert" in sslp:
+ ctx.load_cert_chain(
+ sslp["cert"], keyfile=sslp.get("key"), password=sslp.get("password")
+ )
+ if "cipher" in sslp:
+ ctx.set_ciphers(sslp["cipher"])
ctx.options |= ssl.OP_NO_SSLv2
ctx.options |= ssl.OP_NO_SSLv3
return ctx
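The new top-level ``ssl_*`` keyword arguments feed the dictionary that ``_create_ssl_ctx()`` consumes: ``ssl_verify_cert`` becomes ``verify_mode`` (``False``/"none" maps to ``CERT_NONE``, "optional" to ``CERT_OPTIONAL``, ``True``/"required" to ``CERT_REQUIRED``) and ``ssl_verify_identity`` toggles hostname checking. A sketch with placeholder host and paths:

```python
# Placeholder host and CA path; shows the flattened TLS keyword arguments.
import pymysql

conn = pymysql.connect(
    host="db.example.com",
    user="appuser",
    password="app-password",
    ssl_ca="/path/to/ca.pem",   # becomes the "ca" entry handled above
    ssl_verify_cert=True,       # verify_mode -> ssl.CERT_REQUIRED
    ssl_verify_identity=True,   # check_hostname -> True
)
```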
@@ -357,7 +415,7 @@ def close(self):
self._closed = True
if self._sock is None:
return
-        send_data = struct.pack('<iB', 1, COMMAND.COM_QUIT)
-        if int(self.server_version.split('.', 1)[0]) >= 5:
+ if int(self.server_version.split(".", 1)[0]) >= 5:
self.client_flag |= CLIENT.MULTI_RESULTS
if self.user is None:
raise ValueError("Did not specify a username")
charset_id = charset_by_name(self.charset).id
- if isinstance(self.user, text_type):
+ if isinstance(self.user, str):
self.user = self.user.encode(self.encoding)
- data_init = struct.pack('=5.0)
- data += authresp + b'\0'
+ data += authresp + b"\0"
if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:
- if isinstance(self.db, text_type):
+ if isinstance(self.db, str):
self.db = self.db.encode(self.encoding)
- data += self.db + b'\0'
+ data += self.db + b"\0"
if self.server_capabilities & CLIENT.PLUGIN_AUTH:
- data += (plugin_name or b'') + b'\0'
+ data += (plugin_name or b"") + b"\0"
if self.server_capabilities & CLIENT.CONNECT_ATTRS:
- connect_attrs = b''
+ connect_attrs = b""
for k, v in self._connect_attrs.items():
- k = k.encode('utf8')
- connect_attrs += struct.pack('B', len(k)) + k
- v = v.encode('utf8')
- connect_attrs += struct.pack('B', len(v)) + v
- data += struct.pack('B', len(connect_attrs)) + connect_attrs
+ k = k.encode("utf-8")
+ connect_attrs += _lenenc_int(len(k)) + k
+ v = v.encode("utf-8")
+ connect_attrs += _lenenc_int(len(v)) + v
+ data += _lenenc_int(len(connect_attrs)) + connect_attrs
self.write_packet(data)
auth_packet = self._read_packet()
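Connection attributes are now framed with length-encoded integers instead of a single length byte. As a sketch of what an encoder like ``_lenenc_int`` produces (the thresholds follow the MySQL wire format; treat the exact helper name and signature as an assumption):

```python
# Sketch of MySQL's length-encoded integer framing used for connect attrs.
import struct


def lenenc_int(i: int) -> bytes:
    if i < 0:
        raise ValueError("cannot encode negative integers")
    if i < 0xFB:                                   # 0..250: single byte
        return bytes([i])
    if i < (1 << 16):                              # 0xFC + 2-byte little-endian
        return b"\xfc" + struct.pack("<H", i)
    if i < (1 << 24):                              # 0xFD + 3-byte little-endian
        return b"\xfd" + struct.pack("<I", i)[:3]
    return b"\xfe" + struct.pack("<Q", i)          # 0xFE + 8-byte little-endian


print(lenenc_int(16).hex())    # 10
print(lenenc_int(1000).hex())  # fce803
```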
@@ -854,17 +959,18 @@ def _request_authentication(self):
# if authentication method isn't accepted the first byte
# will have the octet 254
if auth_packet.is_auth_switch_request():
- if DEBUG: print("received auth switch")
+ if DEBUG:
+ print("received auth switch")
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
- auth_packet.read_uint8() # 0xfe packet identifier
+ auth_packet.read_uint8() # 0xfe packet identifier
plugin_name = auth_packet.read_string()
- if self.server_capabilities & CLIENT.PLUGIN_AUTH and plugin_name is not None:
+ if (
+ self.server_capabilities & CLIENT.PLUGIN_AUTH
+ and plugin_name is not None
+ ):
auth_packet = self._process_auth(plugin_name, auth_packet)
else:
- # send legacy handshake
- data = _auth.scramble_old_password(self.password, self.salt) + b'\0'
- self.write_packet(data)
- auth_packet = self._read_packet()
+ raise err.OperationalError("received unknown auth switch request")
elif auth_packet.is_extra_auth_data():
if DEBUG:
print("received extra data")
@@ -874,9 +980,12 @@ def _request_authentication(self):
elif self._auth_plugin_name == "sha256_password":
auth_packet = _auth.sha256_password_auth(self, auth_packet)
else:
- raise err.OperationalError("Received extra packet for auth method %r", self._auth_plugin_name)
+ raise err.OperationalError(
+ "Received extra packet for auth method %r", self._auth_plugin_name
+ )
- if DEBUG: print("Succeed to auth")
+ if DEBUG:
+ print("Succeed to auth")
def _process_auth(self, plugin_name, auth_packet):
handler = self._get_auth_plugin_handler(plugin_name)
@@ -884,20 +993,28 @@ def _process_auth(self, plugin_name, auth_packet):
try:
return handler.authenticate(auth_packet)
except AttributeError:
- if plugin_name != b'dialog':
- raise err.OperationalError(2059, "Authentication plugin '%s'"
- " not loaded: - %r missing authenticate method" % (plugin_name, type(handler)))
+ if plugin_name != b"dialog":
+ raise err.OperationalError(
+ CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
+ f"Authentication plugin '{plugin_name}'"
+ f" not loaded: - {type(handler)!r} missing authenticate method",
+ )
if plugin_name == b"caching_sha2_password":
return _auth.caching_sha2_password_auth(self, auth_packet)
elif plugin_name == b"sha256_password":
return _auth.sha256_password_auth(self, auth_packet)
elif plugin_name == b"mysql_native_password":
data = _auth.scramble_native_password(self.password, auth_packet.read_all())
+ elif plugin_name == b"client_ed25519":
+ data = _auth.ed25519_password(self.password, auth_packet.read_all())
elif plugin_name == b"mysql_old_password":
- data = _auth.scramble_old_password(self.password, auth_packet.read_all()) + b'\0'
+ data = (
+ _auth.scramble_old_password(self.password, auth_packet.read_all())
+ + b"\0"
+ )
elif plugin_name == b"mysql_clear_password":
# https://dev.mysql.com/doc/internals/en/clear-text-authentication.html
- data = self.password + b'\0'
+ data = self.password + b"\0"
elif plugin_name == b"dialog":
pkt = auth_packet
while True:
@@ -907,27 +1024,39 @@ def _process_auth(self, plugin_name, auth_packet):
prompt = pkt.read_all()
if prompt == b"Password: ":
- self.write_packet(self.password + b'\0')
+ self.write_packet(self.password + b"\0")
elif handler:
- resp = 'no response - TypeError within plugin.prompt method'
+ resp = "no response - TypeError within plugin.prompt method"
try:
resp = handler.prompt(echo, prompt)
- self.write_packet(resp + b'\0')
+ self.write_packet(resp + b"\0")
except AttributeError:
- raise err.OperationalError(2059, "Authentication plugin '%s'" \
- " not loaded: - %r missing prompt method" % (plugin_name, handler))
+ raise err.OperationalError(
+ CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
+ f"Authentication plugin '{plugin_name}'"
+ f" not loaded: - {handler!r} missing prompt method",
+ )
except TypeError:
- raise err.OperationalError(2061, "Authentication plugin '%s'" \
- " %r didn't respond with string. Returned '%r' to prompt %r" % (plugin_name, handler, resp, prompt))
+ raise err.OperationalError(
+ CR.CR_AUTH_PLUGIN_ERR,
+ f"Authentication plugin '{plugin_name}'"
+ f" {handler!r} didn't respond with string. Returned '{resp!r}' to prompt {prompt!r}",
+ )
else:
- raise err.OperationalError(2059, "Authentication plugin '%s' (%r) not configured" % (plugin_name, handler))
+ raise err.OperationalError(
+ CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
+ f"Authentication plugin '{plugin_name}' not configured",
+ )
pkt = self._read_packet()
pkt.check_error()
if pkt.is_ok_packet() or last:
break
return pkt
else:
- raise err.OperationalError(2059, "Authentication plugin '%s' not configured" % plugin_name)
+ raise err.OperationalError(
+ CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
+ "Authentication plugin '%s' not configured" % plugin_name,
+ )
self.write_packet(data)
pkt = self._read_packet()
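
For the `dialog` branch above, the handler object comes from the `auth_plugin_map` connection argument: it is constructed with the connection and must expose a `prompt()` (or `authenticate()`) method. A hedged sketch — the plugin name and handler class are illustrative, not part of PyMySQL:

```py
import pymysql

class DialogHandler:
    def __init__(self, con):
        self.con = con  # the handler is constructed with the Connection

    def prompt(self, echo, prompt):
        # Must return bytes; _process_auth() appends the trailing NUL itself.
        return b"my-one-time-password"

con = pymysql.connect(
    host="localhost",
    user="user",
    password="secret",
    auth_plugin_map={"dialog": DialogHandler},
)
```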
@@ -937,13 +1066,16 @@ def _process_auth(self, plugin_name, auth_packet):
def _get_auth_plugin_handler(self, plugin_name):
plugin_class = self._auth_plugin_map.get(plugin_name)
if not plugin_class and isinstance(plugin_name, bytes):
- plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))
+ plugin_class = self._auth_plugin_map.get(plugin_name.decode("ascii"))
if plugin_class:
try:
handler = plugin_class(self)
except TypeError:
- raise err.OperationalError(2059, "Authentication plugin '%s'"
- " not loaded: - %r cannot be constructed with connection object" % (plugin_name, plugin_class))
+ raise err.OperationalError(
+ CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
+ f"Authentication plugin '{plugin_name}'"
+ f" not loaded: - {plugin_class!r} cannot be constructed with connection object",
+ )
else:
handler = None
return handler
@@ -966,24 +1098,24 @@ def _get_server_information(self):
packet = self._read_packet()
data = packet.get_all_data()
- self.protocol_version = byte2int(data[i:i+1])
+ self.protocol_version = data[i]
i += 1
- server_end = data.find(b'\0', i)
- self.server_version = data[i:server_end].decode('latin1')
+ server_end = data.find(b"\0", i)
+ self.server_version = data[i:server_end].decode("latin1")
i = server_end + 1
- self.server_thread_id = struct.unpack('= i + 6:
- lang, stat, cap_h, salt_len = struct.unpack('= i + salt_len:
# salt_len includes auth_plugin_data_part_1 and filler
- self.salt += data[i:i+salt_len]
+ self.salt += data[i : i + salt_len]
i += salt_len
- i+=1
+ i += 1
# AUTH PLUGIN NAME may appear here.
if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i:
# Due to Bug#59453 the auth-plugin-name is missing the terminating
@@ -1017,12 +1151,12 @@ def _get_server_information(self):
# ref: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
# didn't use version checks as mariadb is corrected and reports
# earlier than those two.
- server_end = data.find(b'\0', i)
- if server_end < 0: # pragma: no cover - very specific upstream bug
+ server_end = data.find(b"\0", i)
+ if server_end < 0: # pragma: no cover - very specific upstream bug
# not found \0 and last field so take it all
- self._auth_plugin_name = data[i:].decode('utf-8')
+ self._auth_plugin_name = data[i:].decode("utf-8")
else:
- self._auth_plugin_name = data[i:server_end].decode('utf-8')
+ self._auth_plugin_name = data[i:server_end].decode("utf-8")
def get_server_info(self):
return self.server_version
@@ -1039,8 +1173,7 @@ def get_server_info(self):
NotSupportedError = err.NotSupportedError
-class MySQLResult(object):
-
+class MySQLResult:
def __init__(self, connection):
"""
:type connection: Connection
@@ -1111,7 +1244,8 @@ def _read_ok_packet(self, first_packet):
def _read_load_local_packet(self, first_packet):
if not self.connection._local_infile:
raise RuntimeError(
- "**WARN**: Received LOAD_LOCAL packet but local_infile option is false.")
+ "**WARN**: Received LOAD_LOCAL packet but local_infile option is false."
+ )
load_packet = LoadLocalPacketWrapper(first_packet)
sender = LoadLocalFile(load_packet.filename, self.connection)
try:
@@ -1121,17 +1255,23 @@ def _read_load_local_packet(self, first_packet):
raise
ok_packet = self.connection._read_packet()
- if not ok_packet.is_ok_packet(): # pragma: no cover - upstream induced protocol error
- raise err.OperationalError(2014, "Commands Out of Sync")
+ if (
+ not ok_packet.is_ok_packet()
+ ): # pragma: no cover - upstream induced protocol error
+ raise err.OperationalError(
+ CR.CR_COMMANDS_OUT_OF_SYNC,
+ "Commands Out of Sync",
+ )
self._read_ok_packet(ok_packet)
def _check_packet_is_eof(self, packet):
if not packet.is_eof_packet():
return False
- #TODO: Support CLIENT.DEPRECATE_EOF
+ # TODO: Support CLIENT.DEPRECATE_EOF
# 1) Add DEPRECATE_EOF to CAPABILITIES
# 2) Mask CAPABILITIES with server_capabilities
- # 3) if server_capabilities & CLIENT.DEPRECATE_EOF: use OKPacketWrapper instead of EOFPacketWrapper
+ # 3) if server_capabilities & CLIENT.DEPRECATE_EOF:
+ # use OKPacketWrapper instead of EOFPacketWrapper
wp = EOFPacketWrapper(packet)
self.warning_count = wp.warning_count
self.has_next = wp.has_next
@@ -1165,7 +1305,20 @@ def _finish_unbuffered_query(self):
# in fact, no way to stop MySQL from sending all the data after
# executing a query, so we just spin, and wait for an EOF packet.
while self.unbuffered_active:
- packet = self.connection._read_packet()
+ try:
+ packet = self.connection._read_packet()
+ except err.OperationalError as e:
+ if e.args[0] in (
+ ER.QUERY_TIMEOUT,
+ ER.STATEMENT_TIMEOUT,
+ ):
+ # if the query timed out we can simply ignore this error
+ self.unbuffered_active = False
+ self.connection = None
+ return
+
+ raise
+
if self._check_packet_is_eof(packet):
self.unbuffered_active = False
self.connection = None # release reference to kill cyclic reference.
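
The new `except` clause above lets an unbuffered result that was cut short by a server-side statement timeout be discarded without raising again during cleanup. A hedged sketch of the scenario, assuming MariaDB's `max_statement_time` (MySQL uses `max_execution_time` in milliseconds) and an illustrative `big_table`:

```py
import pymysql
from pymysql.cursors import SSCursor

con = pymysql.connect(host="localhost", user="user", password="secret", db="test")
cur = con.cursor(SSCursor)
cur.execute("SET SESSION max_statement_time=1")  # seconds (MariaDB)
try:
    cur.execute("SELECT * FROM big_table")
    for _ in cur.fetchall_unbuffered():
        pass
except pymysql.err.OperationalError as exc:
    print("statement timed out:", exc.args[0])   # 1969 (MariaDB) or 3024 (MySQL)
finally:
    cur.close()  # with the change above, draining the dead result no longer re-raises
```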
@@ -1195,7 +1348,8 @@ def _read_row_from_packet(self, packet):
if data is not None:
if encoding is not None:
data = data.decode(encoding)
- if DEBUG: print("DEBUG: DATA = ", data)
+ if DEBUG:
+ print("DEBUG: DATA = ", data)
if converter is not None:
data = converter(data)
row.append(data)
@@ -1209,7 +1363,7 @@ def _get_descriptions(self):
conn_encoding = self.connection.encoding
description = []
- for i in range_type(self.field_count):
+ for i in range(self.field_count):
field = self.connection._read_packet(FieldDescriptorPacket)
self.fields.append(field)
description.append(field.description())
@@ -1230,21 +1384,22 @@ def _get_descriptions(self):
encoding = conn_encoding
else:
# Integers, Dates and Times, and other basic data is encoded in ascii
- encoding = 'ascii'
+ encoding = "ascii"
else:
encoding = None
converter = self.connection.decoders.get(field_type)
if converter is converters.through:
converter = None
- if DEBUG: print("DEBUG: field={}, converter={}".format(field, converter))
+ if DEBUG:
+ print(f"DEBUG: field={field}, converter={converter}")
self.converters.append((encoding, converter))
eof_packet = self.connection._read_packet()
- assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF'
+ assert eof_packet.is_eof_packet(), "Protocol error, expecting EOF"
self.description = tuple(description)
-class LoadLocalFile(object):
+class LoadLocalFile:
def __init__(self, filename, connection):
self.filename = filename
self.connection = connection
@@ -1252,19 +1407,25 @@ def __init__(self, filename, connection):
def send_data(self):
"""Send data packets from the local file to the server"""
if not self.connection._sock:
- raise err.InterfaceError("(0, '')")
- conn = self.connection
+ raise err.InterfaceError(0, "")
+ conn: Connection = self.connection
try:
- with open(self.filename, 'rb') as open_file:
- packet_size = min(conn.max_allowed_packet, 16*1024) # 16KB is efficient enough
+ with open(self.filename, "rb") as open_file:
+ packet_size = min(
+ conn.max_allowed_packet, 16 * 1024
+ ) # 16KB is efficient enough
while True:
chunk = open_file.read(packet_size)
if not chunk:
break
conn.write_packet(chunk)
- except IOError:
- raise err.OperationalError(1017, "Can't find file '{0}'".format(self.filename))
+ except OSError:
+ raise err.OperationalError(
+ ER.FILE_NOT_FOUND,
+ f"Can't find file '{self.filename}'",
+ )
finally:
- # send the empty packet to signify we are done sending data
- conn.write_packet(b'')
+ if not conn._closed:
+ # send the empty packet to signify we are done sending data
+ conn.write_packet(b"")
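
For reference, the LOAD DATA LOCAL path handled by `LoadLocalFile` above only sends file contents when the connection was opened with `local_infile=True`; otherwise the RuntimeError in `_read_load_local_packet()` fires. A hedged usage sketch (file path and table name are illustrative):

```py
import pymysql

con = pymysql.connect(
    host="localhost", user="user", password="secret", db="test",
    local_infile=True,  # without this the LOAD_LOCAL packet is rejected client-side
)
cur = con.cursor()
cur.execute(
    "LOAD DATA LOCAL INFILE '/tmp/people.csv' INTO TABLE people "
    "FIELDS TERMINATED BY ',' LINES TERMINATED BY '\\n'"
)
con.commit()
cur.close()
```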
diff --git a/pymysql/constants/CLIENT.py b/pymysql/constants/CLIENT.py
index b42f1523c..34fe57a5d 100644
--- a/pymysql/constants/CLIENT.py
+++ b/pymysql/constants/CLIENT.py
@@ -21,9 +21,16 @@
CONNECT_ATTRS = 1 << 20
PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21
CAPABILITIES = (
- LONG_PASSWORD | LONG_FLAG | PROTOCOL_41 | TRANSACTIONS
- | SECURE_CONNECTION | MULTI_RESULTS
- | PLUGIN_AUTH | PLUGIN_AUTH_LENENC_CLIENT_DATA | CONNECT_ATTRS)
+ LONG_PASSWORD
+ | LONG_FLAG
+ | PROTOCOL_41
+ | TRANSACTIONS
+ | SECURE_CONNECTION
+ | MULTI_RESULTS
+ | PLUGIN_AUTH
+ | PLUGIN_AUTH_LENENC_CLIENT_DATA
+ | CONNECT_ATTRS
+)
# Not done yet
HANDLE_EXPIRED_PASSWORDS = 1 << 22
diff --git a/pymysql/constants/COMMAND.py b/pymysql/constants/COMMAND.py
index 1da275533..2d98850b8 100644
--- a/pymysql/constants/COMMAND.py
+++ b/pymysql/constants/COMMAND.py
@@ -1,4 +1,3 @@
-
COM_SLEEP = 0x00
COM_QUIT = 0x01
COM_INIT_DB = 0x02
@@ -9,12 +8,12 @@
COM_REFRESH = 0x07
COM_SHUTDOWN = 0x08
COM_STATISTICS = 0x09
-COM_PROCESS_INFO = 0x0a
-COM_CONNECT = 0x0b
-COM_PROCESS_KILL = 0x0c
-COM_DEBUG = 0x0d
-COM_PING = 0x0e
-COM_TIME = 0x0f
+COM_PROCESS_INFO = 0x0A
+COM_CONNECT = 0x0B
+COM_PROCESS_KILL = 0x0C
+COM_DEBUG = 0x0D
+COM_PING = 0x0E
+COM_TIME = 0x0F
COM_DELAYED_INSERT = 0x10
COM_CHANGE_USER = 0x11
COM_BINLOG_DUMP = 0x12
@@ -25,9 +24,9 @@
COM_STMT_EXECUTE = 0x17
COM_STMT_SEND_LONG_DATA = 0x18
COM_STMT_CLOSE = 0x19
-COM_STMT_RESET = 0x1a
-COM_SET_OPTION = 0x1b
-COM_STMT_FETCH = 0x1c
-COM_DAEMON = 0x1d
-COM_BINLOG_DUMP_GTID = 0x1e
-COM_END = 0x1f
+COM_STMT_RESET = 0x1A
+COM_SET_OPTION = 0x1B
+COM_STMT_FETCH = 0x1C
+COM_DAEMON = 0x1D
+COM_BINLOG_DUMP_GTID = 0x1E
+COM_END = 0x1F
diff --git a/pymysql/constants/CR.py b/pymysql/constants/CR.py
index 48ca956ec..deae977e5 100644
--- a/pymysql/constants/CR.py
+++ b/pymysql/constants/CR.py
@@ -1,68 +1,79 @@
# flake8: noqa
# errmsg.h
-CR_ERROR_FIRST = 2000
-CR_UNKNOWN_ERROR = 2000
-CR_SOCKET_CREATE_ERROR = 2001
-CR_CONNECTION_ERROR = 2002
-CR_CONN_HOST_ERROR = 2003
-CR_IPSOCK_ERROR = 2004
-CR_UNKNOWN_HOST = 2005
-CR_SERVER_GONE_ERROR = 2006
-CR_VERSION_ERROR = 2007
-CR_OUT_OF_MEMORY = 2008
-CR_WRONG_HOST_INFO = 2009
+CR_ERROR_FIRST = 2000
+CR_UNKNOWN_ERROR = 2000
+CR_SOCKET_CREATE_ERROR = 2001
+CR_CONNECTION_ERROR = 2002
+CR_CONN_HOST_ERROR = 2003
+CR_IPSOCK_ERROR = 2004
+CR_UNKNOWN_HOST = 2005
+CR_SERVER_GONE_ERROR = 2006
+CR_VERSION_ERROR = 2007
+CR_OUT_OF_MEMORY = 2008
+CR_WRONG_HOST_INFO = 2009
CR_LOCALHOST_CONNECTION = 2010
-CR_TCP_CONNECTION = 2011
+CR_TCP_CONNECTION = 2011
CR_SERVER_HANDSHAKE_ERR = 2012
-CR_SERVER_LOST = 2013
+CR_SERVER_LOST = 2013
CR_COMMANDS_OUT_OF_SYNC = 2014
CR_NAMEDPIPE_CONNECTION = 2015
-CR_NAMEDPIPEWAIT_ERROR = 2016
-CR_NAMEDPIPEOPEN_ERROR = 2017
+CR_NAMEDPIPEWAIT_ERROR = 2016
+CR_NAMEDPIPEOPEN_ERROR = 2017
CR_NAMEDPIPESETSTATE_ERROR = 2018
-CR_CANT_READ_CHARSET = 2019
+CR_CANT_READ_CHARSET = 2019
CR_NET_PACKET_TOO_LARGE = 2020
-CR_EMBEDDED_CONNECTION = 2021
-CR_PROBE_SLAVE_STATUS = 2022
-CR_PROBE_SLAVE_HOSTS = 2023
-CR_PROBE_SLAVE_CONNECT = 2024
+CR_EMBEDDED_CONNECTION = 2021
+CR_PROBE_SLAVE_STATUS = 2022
+CR_PROBE_SLAVE_HOSTS = 2023
+CR_PROBE_SLAVE_CONNECT = 2024
CR_PROBE_MASTER_CONNECT = 2025
CR_SSL_CONNECTION_ERROR = 2026
-CR_MALFORMED_PACKET = 2027
-CR_WRONG_LICENSE = 2028
+CR_MALFORMED_PACKET = 2027
+CR_WRONG_LICENSE = 2028
-CR_NULL_POINTER = 2029
-CR_NO_PREPARE_STMT = 2030
-CR_PARAMS_NOT_BOUND = 2031
-CR_DATA_TRUNCATED = 2032
+CR_NULL_POINTER = 2029
+CR_NO_PREPARE_STMT = 2030
+CR_PARAMS_NOT_BOUND = 2031
+CR_DATA_TRUNCATED = 2032
CR_NO_PARAMETERS_EXISTS = 2033
CR_INVALID_PARAMETER_NO = 2034
-CR_INVALID_BUFFER_USE = 2035
+CR_INVALID_BUFFER_USE = 2035
CR_UNSUPPORTED_PARAM_TYPE = 2036
-CR_SHARED_MEMORY_CONNECTION = 2037
-CR_SHARED_MEMORY_CONNECT_REQUEST_ERROR = 2038
-CR_SHARED_MEMORY_CONNECT_ANSWER_ERROR = 2039
+CR_SHARED_MEMORY_CONNECTION = 2037
+CR_SHARED_MEMORY_CONNECT_REQUEST_ERROR = 2038
+CR_SHARED_MEMORY_CONNECT_ANSWER_ERROR = 2039
CR_SHARED_MEMORY_CONNECT_FILE_MAP_ERROR = 2040
-CR_SHARED_MEMORY_CONNECT_MAP_ERROR = 2041
-CR_SHARED_MEMORY_FILE_MAP_ERROR = 2042
-CR_SHARED_MEMORY_MAP_ERROR = 2043
-CR_SHARED_MEMORY_EVENT_ERROR = 2044
+CR_SHARED_MEMORY_CONNECT_MAP_ERROR = 2041
+CR_SHARED_MEMORY_FILE_MAP_ERROR = 2042
+CR_SHARED_MEMORY_MAP_ERROR = 2043
+CR_SHARED_MEMORY_EVENT_ERROR = 2044
CR_SHARED_MEMORY_CONNECT_ABANDONED_ERROR = 2045
-CR_SHARED_MEMORY_CONNECT_SET_ERROR = 2046
-CR_CONN_UNKNOW_PROTOCOL = 2047
-CR_INVALID_CONN_HANDLE = 2048
-CR_SECURE_AUTH = 2049
-CR_FETCH_CANCELED = 2050
-CR_NO_DATA = 2051
-CR_NO_STMT_METADATA = 2052
-CR_NO_RESULT_SET = 2053
-CR_NOT_IMPLEMENTED = 2054
-CR_SERVER_LOST_EXTENDED = 2055
-CR_STMT_CLOSED = 2056
-CR_NEW_STMT_METADATA = 2057
-CR_ALREADY_CONNECTED = 2058
-CR_AUTH_PLUGIN_CANNOT_LOAD = 2059
-CR_DUPLICATE_CONNECTION_ATTR = 2060
-CR_AUTH_PLUGIN_ERR = 2061
-CR_ERROR_LAST = 2061
+CR_SHARED_MEMORY_CONNECT_SET_ERROR = 2046
+CR_CONN_UNKNOW_PROTOCOL = 2047
+CR_INVALID_CONN_HANDLE = 2048
+CR_SECURE_AUTH = 2049
+CR_FETCH_CANCELED = 2050
+CR_NO_DATA = 2051
+CR_NO_STMT_METADATA = 2052
+CR_NO_RESULT_SET = 2053
+CR_NOT_IMPLEMENTED = 2054
+CR_SERVER_LOST_EXTENDED = 2055
+CR_STMT_CLOSED = 2056
+CR_NEW_STMT_METADATA = 2057
+CR_ALREADY_CONNECTED = 2058
+CR_AUTH_PLUGIN_CANNOT_LOAD = 2059
+CR_DUPLICATE_CONNECTION_ATTR = 2060
+CR_AUTH_PLUGIN_ERR = 2061
+CR_INSECURE_API_ERR = 2062
+CR_FILE_NAME_TOO_LONG = 2063
+CR_SSL_FIPS_MODE_ERR = 2064
+CR_DEPRECATED_COMPRESSION_NOT_SUPPORTED = 2065
+CR_COMPRESSION_WRONGLY_CONFIGURED = 2066
+CR_KERBEROS_USER_NOT_FOUND = 2067
+CR_LOAD_DATA_LOCAL_INFILE_REJECTED = 2068
+CR_LOAD_DATA_LOCAL_INFILE_REALPATH_FAIL = 2069
+CR_DNS_SRV_LOOKUP_FAILED = 2070
+CR_MANDATORY_TRACKER_NOT_FOUND = 2071
+CR_INVALID_FACTOR_NO = 2072
+CR_ERROR_LAST = 2072
diff --git a/pymysql/constants/ER.py b/pymysql/constants/ER.py
index 79b88afbe..98729d12d 100644
--- a/pymysql/constants/ER.py
+++ b/pymysql/constants/ER.py
@@ -1,4 +1,3 @@
-
ERROR_FIRST = 1000
HASHCHK = 1000
NISAMCHK = 1001
@@ -471,5 +470,8 @@
WRONG_STRING_LENGTH = 1468
ERROR_LAST = 1468
+# Statement timeouts: STATEMENT_TIMEOUT (1969) is MariaDB's code, QUERY_TIMEOUT (3024) is MySQL's
+STATEMENT_TIMEOUT = 1969
+QUERY_TIMEOUT = 3024
# https://github.com/PyMySQL/PyMySQL/issues/607
CONSTRAINT_FAILED = 4025
diff --git a/pymysql/constants/FIELD_TYPE.py b/pymysql/constants/FIELD_TYPE.py
index 51bd5143b..b8b448660 100644
--- a/pymysql/constants/FIELD_TYPE.py
+++ b/pymysql/constants/FIELD_TYPE.py
@@ -1,5 +1,3 @@
-
-
DECIMAL = 0
TINY = 1
SHORT = 2
diff --git a/pymysql/constants/SERVER_STATUS.py b/pymysql/constants/SERVER_STATUS.py
index 6f5d56630..8f8d77688 100644
--- a/pymysql/constants/SERVER_STATUS.py
+++ b/pymysql/constants/SERVER_STATUS.py
@@ -1,4 +1,3 @@
-
SERVER_STATUS_IN_TRANS = 1
SERVER_STATUS_AUTOCOMMIT = 2
SERVER_MORE_RESULTS_EXISTS = 8
diff --git a/pymysql/converters.py b/pymysql/converters.py
index bf1db9d77..dbf97ca75 100644
--- a/pymysql/converters.py
+++ b/pymysql/converters.py
@@ -1,12 +1,10 @@
-from ._compat import PY2, text_type, long_type, JYTHON, IRONPYTHON, unichr
-
import datetime
from decimal import Decimal
import re
import time
-from .constants import FIELD_TYPE, FLAG
-from .charset import charset_by_id, charset_to_encoding
+from .err import ProgrammingError
+from .constants import FIELD_TYPE
def escape_item(val, charset, mapping=None):
@@ -17,7 +15,7 @@ def escape_item(val, charset, mapping=None):
# Fallback to default when no encoder found
if not encoder:
try:
- encoder = mapping[text_type]
+ encoder = mapping[str]
except KeyError:
raise TypeError("no default type converter defined")
@@ -27,12 +25,10 @@ def escape_item(val, charset, mapping=None):
val = encoder(val, mapping)
return val
+
def escape_dict(val, charset, mapping=None):
- n = {}
- for k, v in val.items():
- quoted = escape_item(v, charset, mapping)
- n[k] = quoted
- return n
+ raise TypeError("dict can not be used as parameter")
+
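
Since `escape_dict()` now refuses dicts outright, the supported way to bind a mapping is to hand it to `execute()` with pyformat placeholders, so each value is escaped individually. A hedged example (table and columns are illustrative):

```py
import pymysql

con = pymysql.connect(host="localhost", user="user", password="secret", db="test")
cur = con.cursor()
cur.execute(
    "INSERT INTO users (name, age) VALUES (%(name)s, %(age)s)",
    {"name": "bob", "age": 21},
)
con.commit()
```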
def escape_sequence(val, charset, mapping=None):
n = []
@@ -41,87 +37,63 @@ def escape_sequence(val, charset, mapping=None):
n.append(quoted)
return "(" + ",".join(n) + ")"
+
def escape_set(val, charset, mapping=None):
- return ','.join([escape_item(x, charset, mapping) for x in val])
+ return ",".join([escape_item(x, charset, mapping) for x in val])
+
def escape_bool(value, mapping=None):
return str(int(value))
-def escape_object(value, mapping=None):
- return str(value)
def escape_int(value, mapping=None):
return str(value)
-def escape_float(value, mapping=None):
- return ('%.15g' % value)
-
-_escape_table = [unichr(x) for x in range(128)]
-_escape_table[0] = u'\\0'
-_escape_table[ord('\\')] = u'\\\\'
-_escape_table[ord('\n')] = u'\\n'
-_escape_table[ord('\r')] = u'\\r'
-_escape_table[ord('\032')] = u'\\Z'
-_escape_table[ord('"')] = u'\\"'
-_escape_table[ord("'")] = u"\\'"
-
-def _escape_unicode(value, mapping=None):
- """escapes *value* without adding quote.
- Value should be unicode
- """
- return value.translate(_escape_table)
+def escape_float(value, mapping=None):
+ s = repr(value)
+ if s in ("inf", "-inf", "nan"):
+ raise ProgrammingError("%s can not be used with MySQL" % s)
+ if "e" not in s:
+ s += "e0"
+ return s
-if PY2:
- def escape_string(value, mapping=None):
- """escape_string escapes *value* but not surround it with quotes.
- Value should be bytes or unicode.
- """
- if isinstance(value, unicode):
- return _escape_unicode(value)
- assert isinstance(value, (bytes, bytearray))
- value = value.replace('\\', '\\\\')
- value = value.replace('\0', '\\0')
- value = value.replace('\n', '\\n')
- value = value.replace('\r', '\\r')
- value = value.replace('\032', '\\Z')
- value = value.replace("'", "\\'")
- value = value.replace('"', '\\"')
- return value
+_escape_table = [chr(x) for x in range(128)]
+_escape_table[0] = "\\0"
+_escape_table[ord("\\")] = "\\\\"
+_escape_table[ord("\n")] = "\\n"
+_escape_table[ord("\r")] = "\\r"
+_escape_table[ord("\032")] = "\\Z"
+_escape_table[ord('"')] = '\\"'
+_escape_table[ord("'")] = "\\'"
- def escape_bytes_prefixed(value, mapping=None):
- assert isinstance(value, (bytes, bytearray))
- return b"_binary'%s'" % escape_string(value)
- def escape_bytes(value, mapping=None):
- assert isinstance(value, (bytes, bytearray))
- return b"'%s'" % escape_string(value)
+def escape_string(value, mapping=None):
+ """escapes *value* without adding quote.
-else:
- escape_string = _escape_unicode
+ Value should be unicode
+ """
+ return value.translate(_escape_table)
- # On Python ~3.5, str.decode('ascii', 'surrogateescape') is slow.
- # (fixed in Python 3.6, http://bugs.python.org/issue24870)
- # Workaround is str.decode('latin1') then translate 0x80-0xff into 0udc80-0udcff.
- # We can escape special chars and surrogateescape at once.
- _escape_bytes_table = _escape_table + [chr(i) for i in range(0xdc80, 0xdd00)]
- def escape_bytes_prefixed(value, mapping=None):
- return "_binary'%s'" % value.decode('latin1').translate(_escape_bytes_table)
+def escape_bytes_prefixed(value, mapping=None):
+ return "_binary'%s'" % value.decode("ascii", "surrogateescape").translate(
+ _escape_table
+ )
- def escape_bytes(value, mapping=None):
- return "'%s'" % value.decode('latin1').translate(_escape_bytes_table)
+def escape_bytes(value, mapping=None):
+ return "'%s'" % value.decode("ascii", "surrogateescape").translate(_escape_table)
-def escape_unicode(value, mapping=None):
- return u"'%s'" % _escape_unicode(value)
def escape_str(value, mapping=None):
return "'%s'" % escape_string(str(value), mapping)
+
def escape_None(value, mapping=None):
- return 'NULL'
+ return "NULL"
+
def escape_timedelta(obj, mapping=None):
seconds = int(obj.seconds) % 60
@@ -133,6 +105,7 @@ def escape_timedelta(obj, mapping=None):
fmt = "'{0:02d}:{1:02d}:{2:02d}'"
return fmt.format(hours, minutes, seconds, obj.microseconds)
+
def escape_time(obj, mapping=None):
if obj.microsecond:
fmt = "'{0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'"
@@ -140,48 +113,61 @@ def escape_time(obj, mapping=None):
fmt = "'{0.hour:02}:{0.minute:02}:{0.second:02}'"
return fmt.format(obj)
+
def escape_datetime(obj, mapping=None):
if obj.microsecond:
- fmt = "'{0.year:04}-{0.month:02}-{0.day:02} {0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'"
+ fmt = (
+ "'{0.year:04}-{0.month:02}-{0.day:02}"
+ + " {0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'"
+ )
else:
fmt = "'{0.year:04}-{0.month:02}-{0.day:02} {0.hour:02}:{0.minute:02}:{0.second:02}'"
return fmt.format(obj)
+
def escape_date(obj, mapping=None):
fmt = "'{0.year:04}-{0.month:02}-{0.day:02}'"
return fmt.format(obj)
+
def escape_struct_time(obj, mapping=None):
return escape_datetime(datetime.datetime(*obj[:6]))
+
+def Decimal2Literal(o, d):
+ return format(o, "f")
+
+
def _convert_second_fraction(s):
if not s:
return 0
# Pad zeros to ensure the fraction length in microseconds
- s = s.ljust(6, '0')
+ s = s.ljust(6, "0")
return int(s[:6])
-DATETIME_RE = re.compile(r"(\d{1,4})-(\d{1,2})-(\d{1,2})[T ](\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?")
+
+DATETIME_RE = re.compile(
+ r"(\d{1,4})-(\d{1,2})-(\d{1,2})[T ](\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?"
+)
def convert_datetime(obj):
"""Returns a DATETIME or TIMESTAMP column value as a datetime object:
- >>> datetime_or_None('2007-02-25 23:06:20')
+ >>> convert_datetime('2007-02-25 23:06:20')
datetime.datetime(2007, 2, 25, 23, 6, 20)
- >>> datetime_or_None('2007-02-25T23:06:20')
+ >>> convert_datetime('2007-02-25T23:06:20')
datetime.datetime(2007, 2, 25, 23, 6, 20)
- Illegal values are returned as None:
-
- >>> datetime_or_None('2007-02-31T23:06:20') is None
- True
- >>> datetime_or_None('0000-00-00 00:00:00') is None
- True
+ Illegal values are returned as str:
+ >>> convert_datetime('2007-02-31T23:06:20')
+ '2007-02-31T23:06:20'
+ >>> convert_datetime('0000-00-00 00:00:00')
+ '0000-00-00 00:00:00'
"""
- if not PY2 and isinstance(obj, (bytes, bytearray)):
- obj = obj.decode('ascii')
+ if isinstance(obj, (bytes, bytearray)):
+ obj = obj.decode("ascii")
m = DATETIME_RE.match(obj)
if not m:
@@ -190,32 +176,33 @@ def convert_datetime(obj):
try:
groups = list(m.groups())
groups[-1] = _convert_second_fraction(groups[-1])
- return datetime.datetime(*[ int(x) for x in groups ])
+ return datetime.datetime(*[int(x) for x in groups])
except ValueError:
return convert_date(obj)
+
TIMEDELTA_RE = re.compile(r"(-)?(\d{1,3}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?")
def convert_timedelta(obj):
"""Returns a TIME column as a timedelta object:
- >>> timedelta_or_None('25:06:17')
- datetime.timedelta(1, 3977)
- >>> timedelta_or_None('-25:06:17')
- datetime.timedelta(-2, 83177)
+ >>> convert_timedelta('25:06:17')
+ datetime.timedelta(days=1, seconds=3977)
+ >>> convert_timedelta('-25:06:17')
+ datetime.timedelta(days=-2, seconds=82423)
- Illegal values are returned as None:
+ Illegal values are returned as string:
- >>> timedelta_or_None('random crap') is None
- True
+ >>> convert_timedelta('random crap')
+ 'random crap'
Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but
can accept values as (+|-)DD HH:MM:SS. The latter format will not
be parsed correctly by this function.
"""
- if not PY2 and isinstance(obj, (bytes, bytearray)):
- obj = obj.decode('ascii')
+ if isinstance(obj, (bytes, bytearray)):
+ obj = obj.decode("ascii")
m = TIMEDELTA_RE.match(obj)
if not m:
@@ -227,31 +214,35 @@ def convert_timedelta(obj):
negate = -1 if groups[0] else 1
hours, minutes, seconds, microseconds = groups[1:]
- tdelta = datetime.timedelta(
- hours = int(hours),
- minutes = int(minutes),
- seconds = int(seconds),
- microseconds = int(microseconds)
- ) * negate
+ tdelta = (
+ datetime.timedelta(
+ hours=int(hours),
+ minutes=int(minutes),
+ seconds=int(seconds),
+ microseconds=int(microseconds),
+ )
+ * negate
+ )
return tdelta
except ValueError:
return obj
+
TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?")
def convert_time(obj):
"""Returns a TIME column as a time object:
- >>> time_or_None('15:06:17')
+ >>> convert_time('15:06:17')
datetime.time(15, 6, 17)
- Illegal values are returned as None:
+ Illegal values are returned as str:
- >>> time_or_None('-25:06:17') is None
- True
- >>> time_or_None('random crap') is None
- True
+ >>> convert_time('-25:06:17')
+ '-25:06:17'
+ >>> convert_time('random crap')
+ 'random crap'
Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but
can accept values as (+|-)DD HH:MM:SS. The latter format will not
@@ -262,8 +253,8 @@ def convert_time(obj):
to be treated as time-of-day and not a time offset, then you can
use set this function as the converter for FIELD_TYPE.TIME.
"""
- if not PY2 and isinstance(obj, (bytes, bytearray)):
- obj = obj.decode('ascii')
+ if isinstance(obj, (bytes, bytearray)):
+ obj = obj.decode("ascii")
m = TIME_RE.match(obj)
if not m:
@@ -273,8 +264,12 @@ def convert_time(obj):
groups = list(m.groups())
groups[-1] = _convert_second_fraction(groups[-1])
hours, minutes, seconds, microseconds = groups
- return datetime.time(hour=int(hours), minute=int(minutes),
- second=int(seconds), microsecond=int(microseconds))
+ return datetime.time(
+ hour=int(hours),
+ minute=int(minutes),
+ second=int(seconds),
+ microsecond=int(microseconds),
+ )
except ValueError:
return obj
@@ -282,70 +277,29 @@ def convert_time(obj):
def convert_date(obj):
"""Returns a DATE column as a date object:
- >>> date_or_None('2007-02-26')
+ >>> convert_date('2007-02-26')
datetime.date(2007, 2, 26)
- Illegal values are returned as None:
-
- >>> date_or_None('2007-02-31') is None
- True
- >>> date_or_None('0000-00-00') is None
- True
+ Illegal values are returned as str:
+ >>> convert_date('2007-02-31')
+ '2007-02-31'
+ >>> convert_date('0000-00-00')
+ '0000-00-00'
"""
- if not PY2 and isinstance(obj, (bytes, bytearray)):
- obj = obj.decode('ascii')
+ if isinstance(obj, (bytes, bytearray)):
+ obj = obj.decode("ascii")
try:
- return datetime.date(*[ int(x) for x in obj.split('-', 2) ])
+ return datetime.date(*[int(x) for x in obj.split("-", 2)])
except ValueError:
return obj
-def convert_mysql_timestamp(timestamp):
- """Convert a MySQL TIMESTAMP to a Timestamp object.
-
- MySQL >= 4.1 returns TIMESTAMP in the same format as DATETIME:
-
- >>> mysql_timestamp_converter('2007-02-25 22:32:17')
- datetime.datetime(2007, 2, 25, 22, 32, 17)
-
- MySQL < 4.1 uses a big string of numbers:
-
- >>> mysql_timestamp_converter('20070225223217')
- datetime.datetime(2007, 2, 25, 22, 32, 17)
-
- Illegal values are returned as None:
-
- >>> mysql_timestamp_converter('2007-02-31 22:32:17') is None
- True
- >>> mysql_timestamp_converter('00000000000000') is None
- True
-
- """
- if not PY2 and isinstance(timestamp, (bytes, bytearray)):
- timestamp = timestamp.decode('ascii')
- if timestamp[4] == '-':
- return convert_datetime(timestamp)
- timestamp += "0"*(14-len(timestamp)) # padding
- year, month, day, hour, minute, second = \
- int(timestamp[:4]), int(timestamp[4:6]), int(timestamp[6:8]), \
- int(timestamp[8:10]), int(timestamp[10:12]), int(timestamp[12:14])
- try:
- return datetime.datetime(year, month, day, hour, minute, second)
- except ValueError:
- return timestamp
-
-def convert_set(s):
- if isinstance(s, (bytes, bytearray)):
- return set(s.split(b","))
- return set(s.split(","))
-
-
def through(x):
return x
-#def convert_bit(b):
+# def convert_bit(b):
# b = "\x00" * (8 - len(b)) + b # pad w/ zeroes
# return struct.unpack(">Q", b)[0]
#
@@ -354,28 +308,12 @@ def through(x):
convert_bit = through
-def convert_characters(connection, field, data):
- field_charset = charset_by_id(field.charsetnr).name
- encoding = charset_to_encoding(field_charset)
- if field.flags & FLAG.SET:
- return convert_set(data.decode(encoding))
- if field.flags & FLAG.BINARY:
- return data
-
- if connection.use_unicode:
- data = data.decode(encoding)
- elif connection.charset != field_charset:
- data = data.decode(encoding)
- data = data.encode(connection.encoding)
- return data
-
encoders = {
bool: escape_bool,
int: escape_int,
- long_type: escape_int,
float: escape_float,
str: escape_str,
- text_type: escape_unicode,
+ bytes: escape_bytes,
tuple: escape_sequence,
list: escape_sequence,
set: escape_sequence,
@@ -387,11 +325,9 @@ def convert_characters(connection, field, data):
datetime.timedelta: escape_timedelta,
datetime.time: escape_time,
time.struct_time: escape_struct_time,
- Decimal: escape_object,
+ Decimal: Decimal2Literal,
}
-if not PY2 or JYTHON or IRONPYTHON:
- encoders[bytes] = escape_bytes
decoders = {
FIELD_TYPE.BIT: convert_bit,
@@ -403,11 +339,10 @@ def convert_characters(connection, field, data):
FIELD_TYPE.LONGLONG: int,
FIELD_TYPE.INT24: int,
FIELD_TYPE.YEAR: int,
- FIELD_TYPE.TIMESTAMP: convert_mysql_timestamp,
+ FIELD_TYPE.TIMESTAMP: convert_datetime,
FIELD_TYPE.DATETIME: convert_datetime,
FIELD_TYPE.TIME: convert_timedelta,
FIELD_TYPE.DATE: convert_date,
- FIELD_TYPE.SET: convert_set,
FIELD_TYPE.BLOB: through,
FIELD_TYPE.TINY_BLOB: through,
FIELD_TYPE.MEDIUM_BLOB: through,
@@ -424,3 +359,5 @@ def convert_characters(connection, field, data):
conversions = encoders.copy()
conversions.update(decoders)
Thing2Literal = escape_str
+
+# Run doctests with `pytest --doctest-modules pymysql/converters.py`
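
A hedged snapshot of how a few of the rewritten converters behave; values are illustrative and the outputs follow directly from the functions above:

```py
from decimal import Decimal

from pymysql import converters

print(converters.escape_float(3.14))       # 3.14e0  -- always rendered as a float literal
print(converters.escape_bytes(b"it's"))    # 'it\'s' -- quoted and backslash-escaped
print(converters.Decimal2Literal(Decimal("1.10"), None))   # 1.10
print(converters.convert_datetime("2007-02-31 00:00:00"))  # returned as str, not None
```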
diff --git a/pymysql/cursors.py b/pymysql/cursors.py
index cc169987b..8be05ca23 100644
--- a/pymysql/cursors.py
+++ b/pymysql/cursors.py
@@ -1,26 +1,22 @@
-# -*- coding: utf-8 -*-
-from __future__ import print_function, absolute_import
-from functools import partial
import re
import warnings
-
-from ._compat import range_type, text_type, PY2
from . import err
#: Regular expression for :meth:`Cursor.executemany`.
-#: executemany only suports simple bulk insert.
+#: executemany only supports simple bulk insert.
#: You can use it to load large dataset.
RE_INSERT_VALUES = re.compile(
- r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)" +
- r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" +
- r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
- re.IGNORECASE | re.DOTALL)
+ r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)"
+ + r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))"
+ + r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
+ re.IGNORECASE | re.DOTALL,
+)
-class Cursor(object):
+class Cursor:
"""
- This is the object you use to interact with the database.
+ This is the object used to interact with the database.
Do not create an instance of a Cursor yourself. Call
connections.Connection.cursor().
@@ -35,10 +31,9 @@ class Cursor(object):
#: Default value of max_allowed_packet is 1048576.
max_stmt_length = 1024000
- _defer_warnings = False
-
def __init__(self, connection):
self.connection = connection
+ self.warning_count = 0
self.description = None
self.rownumber = 0
self.rowcount = -1
@@ -46,7 +41,6 @@ def __init__(self, connection):
self._executed = None
self._result = None
self._rows = None
- self._warnings_handled = False
def close(self):
"""
@@ -87,12 +81,9 @@ def setoutputsizes(self, *args):
"""Does nothing, required by DB API."""
def _nextset(self, unbuffered=False):
- """Get the next query set"""
+ """Get the next query set."""
conn = self._get_db()
current_result = self._result
- # for unbuffered queries warnings are only available once whole result has been read
- if unbuffered:
- self._show_warnings()
if current_result is None or current_result is not conn._result:
return None
if not current_result.has_next:
@@ -106,42 +97,33 @@ def _nextset(self, unbuffered=False):
def nextset(self):
return self._nextset(False)
- def _ensure_bytes(self, x, encoding=None):
- if isinstance(x, text_type):
- x = x.encode(encoding)
- elif isinstance(x, (tuple, list)):
- x = type(x)(self._ensure_bytes(v, encoding=encoding) for v in x)
- return x
-
def _escape_args(self, args, conn):
- ensure_bytes = partial(self._ensure_bytes, encoding=conn.encoding)
-
if isinstance(args, (tuple, list)):
- if PY2:
- args = tuple(map(ensure_bytes, args))
return tuple(conn.literal(arg) for arg in args)
elif isinstance(args, dict):
- if PY2:
- args = dict((ensure_bytes(key), ensure_bytes(val)) for
- (key, val) in args.items())
- return dict((key, conn.literal(val)) for (key, val) in args.items())
+ return {key: conn.literal(val) for (key, val) in args.items()}
else:
# If it's not a dictionary let's try escaping it anyways.
# Worst case it will throw a Value error
- if PY2:
- args = ensure_bytes(args)
return conn.escape(args)
def mogrify(self, query, args=None):
"""
- Returns the exact string that is sent to the database by calling the
+ Returns the exact string that would be sent to the database by calling the
execute() method.
+ :param query: Query to mogrify.
+ :type query: str
+
+ :param args: Parameters used with query. (optional)
+ :type args: tuple, list or dict
+
+ :return: The query with argument binding applied.
+ :rtype: str
+
This method follows the extension to the DB API 2.0 followed by Psycopg.
"""
conn = self._get_db()
- if PY2: # Use bytes on Python 2 always
- query = self._ensure_bytes(query, encoding=conn.encoding)
if args is not None:
query = query % self._escape_args(args, conn)
@@ -149,14 +131,15 @@ def mogrify(self, query, args=None):
return query
def execute(self, query, args=None):
- """Execute a query
+ """Execute a query.
- :param str query: Query to execute.
+ :param query: Query to execute.
+ :type query: str
- :param args: parameters used with query. (optional)
+ :param args: Parameters used with query. (optional)
:type args: tuple, list or dict
- :return: Number of affected rows
+ :return: Number of affected rows.
:rtype: int
If args is a list or tuple, %s can be used as a placeholder in the query.
@@ -172,12 +155,16 @@ def execute(self, query, args=None):
return result
def executemany(self, query, args):
- # type: (str, list) -> int
- """Run several data against one query
+ """Run several data against one query.
+
+ :param query: Query to execute.
+ :type query: str
+
+    :param args: Sequence of sequences or mappings; each item is used as query parameters.
+ :type args: tuple or list
- :param query: query to execute on server
- :param args: Sequence of sequences or mappings. It is used as parameter.
:return: Number of rows affected, if any.
+ :rtype: int or None
This method improves performance on multiple-row INSERT and
REPLACE. Otherwise it is equivalent to looping over args with
@@ -190,57 +177,58 @@ def executemany(self, query, args):
if m:
q_prefix = m.group(1) % ()
q_values = m.group(2).rstrip()
- q_postfix = m.group(3) or ''
- assert q_values[0] == '(' and q_values[-1] == ')'
- return self._do_execute_many(q_prefix, q_values, q_postfix, args,
- self.max_stmt_length,
- self._get_db().encoding)
+ q_postfix = m.group(3) or ""
+ assert q_values[0] == "(" and q_values[-1] == ")"
+ return self._do_execute_many(
+ q_prefix,
+ q_values,
+ q_postfix,
+ args,
+ self.max_stmt_length,
+ self._get_db().encoding,
+ )
self.rowcount = sum(self.execute(query, arg) for arg in args)
return self.rowcount
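
A hedged usage example for the bulk path above: statements matching RE_INSERT_VALUES are rewritten into multi-row INSERTs, batched up to max_stmt_length bytes per round trip (table name illustrative):

```py
import pymysql

con = pymysql.connect(host="localhost", user="user", password="secret", db="test")
cur = con.cursor()
cur.executemany(
    "INSERT INTO users (name, age) VALUES (%s, %s)",
    [("bob", 21), ("jim", 56), ("fred", 100)],
)
con.commit()
print(cur.rowcount)  # 3
```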
- def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding):
+ def _do_execute_many(
+ self, prefix, values, postfix, args, max_stmt_length, encoding
+ ):
conn = self._get_db()
escape = self._escape_args
- if isinstance(prefix, text_type):
+ if isinstance(prefix, str):
prefix = prefix.encode(encoding)
- if PY2 and isinstance(values, text_type):
- values = values.encode(encoding)
- if isinstance(postfix, text_type):
+ if isinstance(postfix, str):
postfix = postfix.encode(encoding)
sql = bytearray(prefix)
args = iter(args)
v = values % escape(next(args), conn)
- if isinstance(v, text_type):
- if PY2:
- v = v.encode(encoding)
- else:
- v = v.encode(encoding, 'surrogateescape')
+ if isinstance(v, str):
+ v = v.encode(encoding, "surrogateescape")
sql += v
rows = 0
for arg in args:
v = values % escape(arg, conn)
- if isinstance(v, text_type):
- if PY2:
- v = v.encode(encoding)
- else:
- v = v.encode(encoding, 'surrogateescape')
+ if isinstance(v, str):
+ v = v.encode(encoding, "surrogateescape")
if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
rows += self.execute(sql + postfix)
sql = bytearray(prefix)
else:
- sql += b','
+ sql += b","
sql += v
rows += self.execute(sql + postfix)
self.rowcount = rows
return rows
def callproc(self, procname, args=()):
- """Execute stored procedure procname with args
+ """Execute stored procedure procname with args.
- procname -- string, name of procedure to execute on server
+ :param procname: Name of procedure to execute on server.
+ :type procname: str
- args -- Sequence of parameters to use with procedure
+ :param args: Sequence of parameters to use with procedure.
+ :type args: tuple or list
Returns the original args.
@@ -265,20 +253,25 @@ def callproc(self, procname, args=()):
"""
conn = self._get_db()
if args:
- fmt = '@_{0}_%d=%s'.format(procname)
- self._query('SET %s' % ','.join(fmt % (index, conn.escape(arg))
- for index, arg in enumerate(args)))
+ fmt = f"@_{procname}_%d=%s"
+ self._query(
+ "SET %s"
+ % ",".join(
+ fmt % (index, conn.escape(arg)) for index, arg in enumerate(args)
+ )
+ )
self.nextset()
- q = "CALL %s(%s)" % (procname,
- ','.join(['@_%s_%d' % (procname, i)
- for i in range_type(len(args))]))
+ q = "CALL {}({})".format(
+ procname,
+ ",".join(["@_%s_%d" % (procname, i) for i in range(len(args))]),
+ )
self._query(q)
self._executed = q
return args
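
`callproc()` above binds each argument to an `@_procname_N` session variable before issuing CALL; reading those variables back is how (possibly modified) parameters are retrieved. A hedged sketch with an illustrative procedure name:

```py
import pymysql

con = pymysql.connect(host="localhost", user="user", password="secret", db="test")
cur = con.cursor()
cur.callproc("add_user", ("bob", 21))
cur.execute("SELECT @_add_user_0, @_add_user_1")  # the bound parameters
print(cur.fetchone())
```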
def fetchone(self):
- """Fetch the next row"""
+ """Fetch the next row."""
self._check_executed()
if self._rows is None or self.rownumber >= len(self._rows):
return None
@@ -287,32 +280,34 @@ def fetchone(self):
return result
def fetchmany(self, size=None):
- """Fetch several rows"""
+ """Fetch several rows."""
self._check_executed()
if self._rows is None:
+ # Django expects () for EOF.
+ # https://github.com/django/django/blob/0c1518ee429b01c145cf5b34eab01b0b92f8c246/django/db/backends/mysql/features.py#L8
return ()
end = self.rownumber + (size or self.arraysize)
- result = self._rows[self.rownumber:end]
+ result = self._rows[self.rownumber : end]
self.rownumber = min(end, len(self._rows))
return result
def fetchall(self):
- """Fetch all the rows"""
+ """Fetch all the rows."""
self._check_executed()
if self._rows is None:
- return ()
+ return []
if self.rownumber:
- result = self._rows[self.rownumber:]
+ result = self._rows[self.rownumber :]
else:
result = self._rows
self.rownumber = len(self._rows)
return result
- def scroll(self, value, mode='relative'):
+ def scroll(self, value, mode="relative"):
self._check_executed()
- if mode == 'relative':
+ if mode == "relative":
r = self.rownumber + value
- elif mode == 'absolute':
+ elif mode == "absolute":
r = value
else:
raise err.ProgrammingError("unknown scroll mode %s" % mode)
@@ -323,7 +318,6 @@ def scroll(self, value, mode='relative'):
def _query(self, q):
conn = self._get_db()
- self._last_executed = q
self._clear_result()
conn.query(q)
self._do_get_result()
@@ -334,6 +328,7 @@ def _clear_result(self):
self._result = None
self.rowcount = 0
+ self.warning_count = 0
self.description = None
self.lastrowid = None
self._rows = None
@@ -344,57 +339,57 @@ def _do_get_result(self):
self._result = result = conn._result
self.rowcount = result.affected_rows
+ self.warning_count = result.warning_count
self.description = result.description
self.lastrowid = result.insert_id
self._rows = result.rows
- self._warnings_handled = False
-
- if not self._defer_warnings:
- self._show_warnings()
-
- def _show_warnings(self):
- if self._warnings_handled:
- return
- self._warnings_handled = True
- if self._result and (self._result.has_next or not self._result.warning_count):
- return
- ws = self._get_db().show_warnings()
- if ws is None:
- return
- for w in ws:
- msg = w[-1]
- if PY2:
- if isinstance(msg, unicode):
- msg = msg.encode('utf-8', 'replace')
- warnings.warn(err.Warning(*w[1:3]), stacklevel=4)
def __iter__(self):
- return iter(self.fetchone, None)
-
- Warning = err.Warning
- Error = err.Error
- InterfaceError = err.InterfaceError
- DatabaseError = err.DatabaseError
- DataError = err.DataError
- OperationalError = err.OperationalError
- IntegrityError = err.IntegrityError
- InternalError = err.InternalError
- ProgrammingError = err.ProgrammingError
- NotSupportedError = err.NotSupportedError
+ return self
+ def __next__(self):
+ row = self.fetchone()
+ if row is None:
+ raise StopIteration
+ return row
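
With `__iter__`/`__next__` defined as above, a buffered cursor is itself an iterator over the fetched rows. A hedged example (table name illustrative):

```py
import pymysql

con = pymysql.connect(host="localhost", user="user", password="secret", db="test")
cur = con.cursor()
cur.execute("SELECT name, age FROM users")
for name, age in cur:  # iteration calls fetchone() until it returns None
    print(name, age)
```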
-class DictCursorMixin(object):
+ def __getattr__(self, name):
+ # DB-API 2.0 optional extension says these errors can be accessed
+ # via Connection object. But MySQLdb had defined them on Cursor object.
+ if name in (
+ "Warning",
+ "Error",
+ "InterfaceError",
+ "DatabaseError",
+ "DataError",
+ "OperationalError",
+ "IntegrityError",
+ "InternalError",
+ "ProgrammingError",
+ "NotSupportedError",
+ ):
+ # Deprecated since v1.1
+ warnings.warn(
+                "PyMySQL errors should be accessed from the `pymysql` package",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return getattr(err, name)
+ raise AttributeError(name)
+
+
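
Per the deprecation above, exception classes should be taken from the package rather than from cursor attributes; a hedged example (table name illustrative):

```py
import pymysql

con = pymysql.connect(host="localhost", user="user", password="secret", db="test")
cur = con.cursor()
try:
    cur.execute("SELECT * FROM no_such_table")
except pymysql.ProgrammingError as exc:  # preferred over the deprecated cur.ProgrammingError
    print(exc.args)
```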
+class DictCursorMixin:
# You can override this to use OrderedDict or other dict-like types.
dict_type = dict
def _do_get_result(self):
- super(DictCursorMixin, self)._do_get_result()
+ super()._do_get_result()
fields = []
if self.description:
for f in self._result.fields:
name = f.name
if name in fields:
- name = f.table_name + '.' + name
+ name = f.table_name + "." + name
fields.append(name)
self._fields = fields
@@ -427,8 +422,6 @@ class SSCursor(Cursor):
possible to scroll backwards, as only the current row is held in memory.
"""
- _defer_warnings = True
-
def _conv_row(self, row):
return row
@@ -450,7 +443,6 @@ def close(self):
def _query(self, q):
conn = self._get_db()
- self._last_executed = q
self._clear_result()
conn.query(q, unbuffered=True)
self._do_get_result()
@@ -460,15 +452,15 @@ def nextset(self):
return self._nextset(unbuffered=True)
def read_next(self):
- """Read next row"""
+ """Read next row."""
return self._conv_row(self._result._read_rowdata_packet_unbuffered())
def fetchone(self):
- """Fetch next row"""
+ """Fetch next row."""
self._check_executed()
row = self.read_next()
if row is None:
- self._show_warnings()
+ self.warning_count = self._result.warning_count
return None
self.rownumber += 1
return row
@@ -489,43 +481,46 @@ def fetchall_unbuffered(self):
"""
return iter(self.fetchone, None)
- def __iter__(self):
- return self.fetchall_unbuffered()
-
def fetchmany(self, size=None):
- """Fetch many"""
+ """Fetch many."""
self._check_executed()
if size is None:
size = self.arraysize
rows = []
- for i in range_type(size):
+ for i in range(size):
row = self.read_next()
if row is None:
- self._show_warnings()
+ self.warning_count = self._result.warning_count
break
rows.append(row)
self.rownumber += 1
+ if not rows:
+ # Django expects () for EOF.
+ # https://github.com/django/django/blob/0c1518ee429b01c145cf5b34eab01b0b92f8c246/django/db/backends/mysql/features.py#L8
+ return ()
return rows
- def scroll(self, value, mode='relative'):
+ def scroll(self, value, mode="relative"):
self._check_executed()
- if mode == 'relative':
+ if mode == "relative":
if value < 0:
raise err.NotSupportedError(
- "Backwards scrolling not supported by this cursor")
+ "Backwards scrolling not supported by this cursor"
+ )
- for _ in range_type(value):
+ for _ in range(value):
self.read_next()
self.rownumber += value
- elif mode == 'absolute':
+ elif mode == "absolute":
if value < self.rownumber:
raise err.NotSupportedError(
- "Backwards scrolling not supported by this cursor")
+ "Backwards scrolling not supported by this cursor"
+ )
end = value - self.rownumber
- for _ in range_type(end):
+ for _ in range(end):
self.read_next()
self.rownumber = value
else:
diff --git a/pymysql/err.py b/pymysql/err.py
index fbc60558e..dac65d3be 100644
--- a/pymysql/err.py
+++ b/pymysql/err.py
@@ -74,36 +74,77 @@ def _map_error(exc, *errors):
error_map[error] = exc
-_map_error(ProgrammingError, ER.DB_CREATE_EXISTS, ER.SYNTAX_ERROR,
- ER.PARSE_ERROR, ER.NO_SUCH_TABLE, ER.WRONG_DB_NAME,
- ER.WRONG_TABLE_NAME, ER.FIELD_SPECIFIED_TWICE,
- ER.INVALID_GROUP_FUNC_USE, ER.UNSUPPORTED_EXTENSION,
- ER.TABLE_MUST_HAVE_COLUMNS, ER.CANT_DO_THIS_DURING_AN_TRANSACTION,
- ER.WRONG_DB_NAME, ER.WRONG_COLUMN_NAME,
- )
-_map_error(DataError, ER.WARN_DATA_TRUNCATED, ER.WARN_NULL_TO_NOTNULL,
- ER.WARN_DATA_OUT_OF_RANGE, ER.NO_DEFAULT, ER.PRIMARY_CANT_HAVE_NULL,
- ER.DATA_TOO_LONG, ER.DATETIME_FUNCTION_OVERFLOW)
-_map_error(IntegrityError, ER.DUP_ENTRY, ER.NO_REFERENCED_ROW,
- ER.NO_REFERENCED_ROW_2, ER.ROW_IS_REFERENCED, ER.ROW_IS_REFERENCED_2,
- ER.CANNOT_ADD_FOREIGN, ER.BAD_NULL_ERROR)
-_map_error(NotSupportedError, ER.WARNING_NOT_COMPLETE_ROLLBACK,
- ER.NOT_SUPPORTED_YET, ER.FEATURE_DISABLED, ER.UNKNOWN_STORAGE_ENGINE)
-_map_error(OperationalError, ER.DBACCESS_DENIED_ERROR, ER.ACCESS_DENIED_ERROR,
- ER.CON_COUNT_ERROR, ER.TABLEACCESS_DENIED_ERROR,
- ER.COLUMNACCESS_DENIED_ERROR, ER.CONSTRAINT_FAILED, ER.LOCK_DEADLOCK)
+_map_error(
+ ProgrammingError,
+ ER.DB_CREATE_EXISTS,
+ ER.SYNTAX_ERROR,
+ ER.PARSE_ERROR,
+ ER.NO_SUCH_TABLE,
+ ER.WRONG_DB_NAME,
+ ER.WRONG_TABLE_NAME,
+ ER.FIELD_SPECIFIED_TWICE,
+ ER.INVALID_GROUP_FUNC_USE,
+ ER.UNSUPPORTED_EXTENSION,
+ ER.TABLE_MUST_HAVE_COLUMNS,
+ ER.CANT_DO_THIS_DURING_AN_TRANSACTION,
+ ER.WRONG_DB_NAME,
+ ER.WRONG_COLUMN_NAME,
+)
+_map_error(
+ DataError,
+ ER.WARN_DATA_TRUNCATED,
+ ER.WARN_NULL_TO_NOTNULL,
+ ER.WARN_DATA_OUT_OF_RANGE,
+ ER.NO_DEFAULT,
+ ER.PRIMARY_CANT_HAVE_NULL,
+ ER.DATA_TOO_LONG,
+ ER.DATETIME_FUNCTION_OVERFLOW,
+ ER.TRUNCATED_WRONG_VALUE_FOR_FIELD,
+ ER.ILLEGAL_VALUE_FOR_TYPE,
+)
+_map_error(
+ IntegrityError,
+ ER.DUP_ENTRY,
+ ER.NO_REFERENCED_ROW,
+ ER.NO_REFERENCED_ROW_2,
+ ER.ROW_IS_REFERENCED,
+ ER.ROW_IS_REFERENCED_2,
+ ER.CANNOT_ADD_FOREIGN,
+ ER.BAD_NULL_ERROR,
+)
+_map_error(
+ NotSupportedError,
+ ER.WARNING_NOT_COMPLETE_ROLLBACK,
+ ER.NOT_SUPPORTED_YET,
+ ER.FEATURE_DISABLED,
+ ER.UNKNOWN_STORAGE_ENGINE,
+)
+_map_error(
+ OperationalError,
+ ER.DBACCESS_DENIED_ERROR,
+ ER.ACCESS_DENIED_ERROR,
+ ER.CON_COUNT_ERROR,
+ ER.TABLEACCESS_DENIED_ERROR,
+ ER.COLUMNACCESS_DENIED_ERROR,
+ ER.CONSTRAINT_FAILED,
+ ER.LOCK_DEADLOCK,
+)
del _map_error, ER
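
A hedged illustration of the mapping just built: a duplicate key (ER.DUP_ENTRY, errno 1062) surfaces as IntegrityError rather than a generic error (table name illustrative):

```py
import pymysql

con = pymysql.connect(host="localhost", user="user", password="secret", db="test")
cur = con.cursor()
cur.execute("CREATE TABLE t (id INT PRIMARY KEY)")
cur.execute("INSERT INTO t VALUES (1)")
try:
    cur.execute("INSERT INTO t VALUES (1)")
except pymysql.err.IntegrityError as exc:
    print(exc.args[0])  # 1062 (ER.DUP_ENTRY)
```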
def raise_mysql_exception(data):
- errno = struct.unpack('= 2 and value[0] == value[-1] == quote:
return value[1:-1]
return value
+ def optionxform(self, key):
+ return key.lower().replace("_", "-")
+
def get(self, section, option):
value = configparser.RawConfigParser.get(self, section, option)
return self.__remove_quotes(value)
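
The `optionxform()` override above normalizes option-file keys, so `default_character_set` and `default-character-set` resolve to the same option when a config file is read at connect time. A hedged sketch (path and values illustrative):

```py
import pymysql

# ~/.my.cnf might contain:
#   [client]
#   user = appuser
#   password = secret
#   default-character-set = utf8mb4
con = pymysql.connect(
    host="localhost",
    db="test",
    read_default_file="~/.my.cnf",
    read_default_group="client",
)
```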
diff --git a/pymysql/protocol.py b/pymysql/protocol.py
index 8ccf7c4d7..98fde6d0c 100644
--- a/pymysql/protocol.py
+++ b/pymysql/protocol.py
@@ -1,12 +1,9 @@
# Python implementation of low level MySQL client-server protocol
# http://dev.mysql.com/doc/internals/en/client-server-protocol.html
-from __future__ import print_function
from .charset import MBLENGTH
-from ._compat import PY2, range_type
from .constants import FIELD_TYPE, SERVER_STATUS
from . import err
-from .util import byte2int
import struct
import sys
@@ -23,11 +20,9 @@
def dump_packet(data): # pragma: no cover
def printable(data):
- if 32 <= byte2int(data) < 127:
- if isinstance(data, int):
- return chr(data)
- return data
- return '.'
+ if 32 <= data < 127:
+ return chr(data)
+ return "."
try:
print("packet length:", len(data))
@@ -37,21 +32,25 @@ def printable(data):
print("-" * 66)
except ValueError:
pass
- dump_data = [data[i:i+16] for i in range_type(0, min(len(data), 256), 16)]
+ dump_data = [data[i : i + 16] for i in range(0, min(len(data), 256), 16)]
for d in dump_data:
- print(' '.join("{:02X}".format(byte2int(x)) for x in d) +
- ' ' * (16 - len(d)) + ' ' * 2 +
- ''.join(printable(x) for x in d))
+ print(
+ " ".join(f"{x:02X}" for x in d)
+ + " " * (16 - len(d))
+ + " " * 2
+ + "".join(printable(x) for x in d)
+ )
print("-" * 66)
print()
-class MysqlPacket(object):
+class MysqlPacket:
"""Representation of a MySQL response packet.
Provides an interface for reading/parsing the packet results.
"""
- __slots__ = ('_position', '_data')
+
+ __slots__ = ("_position", "_data")
def __init__(self, data, encoding):
self._position = 0
@@ -62,11 +61,12 @@ def get_all_data(self):
def read(self, size):
"""Read the first 'size' bytes in packet and advance cursor past them."""
- result = self._data[self._position:(self._position+size)]
+ result = self._data[self._position : (self._position + size)]
if len(result) != size:
- error = ('Result length not requested length:\n'
- 'Expected=%s. Actual=%s. Position: %s. Data Length: %s'
- % (size, len(result), self._position, len(self._data)))
+ error = (
+ "Result length not requested length:\n"
+ f"Expected={size}. Actual={len(result)}. Position: {self._position}. Data Length: {len(self._data)}"
+ )
if DEBUG:
print(error)
self.dump()
@@ -79,7 +79,7 @@ def read_all(self):
(Subsequent read() will return errors.)
"""
- result = self._data[self._position:]
+ result = self._data[self._position :]
self._position = None # ensure no subsequent read()
return result
@@ -87,8 +87,9 @@ def advance(self, length):
"""Advance the cursor in data buffer 'length' bytes."""
new_position = self._position + length
if new_position < 0 or new_position > len(self._data):
- raise Exception('Invalid advance amount (%s) for cursor. '
- 'Position=%s' % (length, new_position))
+ raise Exception(
+ f"Invalid advance amount ({length}) for cursor. Position={new_position}"
+ )
self._position = new_position
def rewind(self, position=0):
@@ -106,44 +107,38 @@ def get_bytes(self, position, length=1):
No error checking is done. If requesting outside end of buffer
an empty string (or string shorter than 'length') may be returned!
"""
- return self._data[position:(position+length)]
-
- if PY2:
- def read_uint8(self):
- result = ord(self._data[self._position])
- self._position += 1
- return result
- else:
- def read_uint8(self):
- result = self._data[self._position]
- self._position += 1
- return result
+ return self._data[position : (position + length)]
+
+ def read_uint8(self):
+ result = self._data[self._position]
+ self._position += 1
+ return result
def read_uint16(self):
- result = struct.unpack_from('= 7
+ return self._data[0] == 0 and len(self._data) >= 7
def is_eof_packet(self):
# http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet
# Caution: \xFE may be LengthEncodedInteger.
# If \xFE is LengthEncodedInteger header, 8bytes followed.
- return self._data[0:1] == b'\xfe' and len(self._data) < 9
+ return self._data[0] == 0xFE and len(self._data) < 9
def is_auth_switch_request(self):
# http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
- return self._data[0:1] == b'\xfe'
+ return self._data[0] == 0xFE
def is_extra_auth_data(self):
# https://dev.mysql.com/doc/internals/en/successful-authentication.html
- return self._data[0:1] == b'\x01'
+ return self._data[0] == 1
def is_resultset_packet(self):
- field_count = ord(self._data[0:1])
+ field_count = self._data[0]
return 1 <= field_count <= 250
def is_load_local_packet(self):
- return self._data[0:1] == b'\xfb'
+ return self._data[0] == 0xFB
def is_error_packet(self):
- return self._data[0:1] == b'\xff'
+ return self._data[0] == 0xFF
def check_error(self):
if self.is_error_packet():
- self.rewind()
- self.advance(1) # field_count == error (we already know that)
- errno = self.read_uint16()
- if DEBUG: print("errno =", errno)
- err.raise_mysql_exception(self._data)
+ self.raise_for_error()
+
+ def raise_for_error(self):
+ self.rewind()
+ self.advance(1) # field_count == error (we already know that)
+ errno = self.read_uint16()
+ if DEBUG:
+ print("errno =", errno)
+ err.raise_mysql_exception(self._data)
def dump(self):
dump_packet(self._data)
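
The packet-type predicates in this file now compare the first byte as an int (`data[0] == 0xFE`) instead of slicing out a one-byte bytes object, which is equivalent on Python 3 since indexing `bytes` yields an int. A small hedged demonstration:

```py
data = b"\xfe\x00\x00\x00"

print(data[0:1] == b"\xfe")  # True -- old style: compare a 1-byte slice
print(data[0] == 0xFE)       # True -- new style: indexing bytes gives an int
```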
@@ -245,8 +244,13 @@ def _parse_field_descriptor(self, encoding):
self.org_table = self.read_length_coded_string().decode(encoding)
self.name = self.read_length_coded_string().decode(encoding)
self.org_name = self.read_length_coded_string().decode(encoding)
- self.charsetnr, self.length, self.type_code, self.flags, self.scale = (
- self.read_struct('= version_tuple
+ def get_mysql_vendor(self, conn):
+ server_version = conn.get_server_info()
+
+ if "MariaDB" in server_version:
+ return "mariadb"
+
+ return "mysql"
+
_connections = None
@property
@@ -55,10 +71,12 @@ def connect(self, **params):
p = self.databases[0].copy()
p.update(params)
conn = pymysql.connect(**p)
+
@self.addCleanup
def teardown():
if conn.open:
conn.close()
+
return conn
def _teardown_connections(self):
@@ -80,7 +98,7 @@ def safe_create_table(self, connection, tablename, ddl, cleanup=True):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
- cursor.execute("drop table if exists `%s`" % (tablename,))
+ cursor.execute(f"drop table if exists `{tablename}`")
cursor.execute(ddl)
cursor.close()
if cleanup:
@@ -90,15 +108,5 @@ def drop_table(self, connection, tablename):
cursor = connection.cursor()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
- cursor.execute("drop table if exists `%s`" % (tablename,))
+ cursor.execute(f"drop table if exists `{tablename}`")
cursor.close()
-
- def safe_gc_collect(self):
- """Ensure cycles are collected via gc.
-
- Runs additional times on non-CPython platforms.
-
- """
- gc.collect()
- if not CPYTHON:
- gc.collect()
diff --git a/pymysql/tests/test_DictCursor.py b/pymysql/tests/test_DictCursor.py
index 9a0d638b2..4e545792a 100644
--- a/pymysql/tests/test_DictCursor.py
+++ b/pymysql/tests/test_DictCursor.py
@@ -6,46 +6,50 @@
class TestDictCursor(base.PyMySQLTestCase):
- bob = {'name': 'bob', 'age': 21, 'DOB': datetime.datetime(1990, 2, 6, 23, 4, 56)}
- jim = {'name': 'jim', 'age': 56, 'DOB': datetime.datetime(1955, 5, 9, 13, 12, 45)}
- fred = {'name': 'fred', 'age': 100, 'DOB': datetime.datetime(1911, 9, 12, 1, 1, 1)}
+ bob = {"name": "bob", "age": 21, "DOB": datetime.datetime(1990, 2, 6, 23, 4, 56)}
+ jim = {"name": "jim", "age": 56, "DOB": datetime.datetime(1955, 5, 9, 13, 12, 45)}
+ fred = {"name": "fred", "age": 100, "DOB": datetime.datetime(1911, 9, 12, 1, 1, 1)}
cursor_type = pymysql.cursors.DictCursor
def setUp(self):
- super(TestDictCursor, self).setUp()
- self.conn = conn = self.connections[0]
+ super().setUp()
+ self.conn = conn = self.connect()
c = conn.cursor(self.cursor_type)
- # create a table ane some data to query
+ # create a table and some data to query
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
c.execute("drop table if exists dictcursor")
# include in filterwarnings since for unbuffered dict cursor warning for lack of table
# will only be propagated at start of next execute() call
- c.execute("""CREATE TABLE dictcursor (name char(20), age int , DOB datetime)""")
- data = [("bob", 21, "1990-02-06 23:04:56"),
- ("jim", 56, "1955-05-09 13:12:45"),
- ("fred", 100, "1911-09-12 01:01:01")]
+ c.execute(
+ """CREATE TABLE dictcursor (name char(20), age int , DOB datetime)"""
+ )
+ data = [
+ ("bob", 21, "1990-02-06 23:04:56"),
+ ("jim", 56, "1955-05-09 13:12:45"),
+ ("fred", 100, "1911-09-12 01:01:01"),
+ ]
c.executemany("insert into dictcursor values (%s,%s,%s)", data)
def tearDown(self):
c = self.conn.cursor()
c.execute("drop table dictcursor")
- super(TestDictCursor, self).tearDown()
+ super().tearDown()
def _ensure_cursor_expired(self, cursor):
pass
def test_DictCursor(self):
bob, jim, fred = self.bob.copy(), self.jim.copy(), self.fred.copy()
- #all assert test compare to the structure as would come out from MySQLdb
+        # all assert tests compare to the structure as would come out from MySQLdb
conn = self.conn
c = conn.cursor(self.cursor_type)
# try an update which should return no rows
c.execute("update dictcursor set age=20 where name='bob'")
- bob['age'] = 20
+ bob["age"] = 20
# pull back the single row dict for bob and check
c.execute("SELECT * from dictcursor where name='bob'")
r = c.fetchone()
@@ -55,19 +59,23 @@ def test_DictCursor(self):
# same again, but via fetchall => tuple)
c.execute("SELECT * from dictcursor where name='bob'")
r = c.fetchall()
- self.assertEqual([bob], r, "fetch a 1 row result via fetchall failed via DictCursor")
+ self.assertEqual(
+ [bob], r, "fetch a 1 row result via fetchall failed via DictCursor"
+ )
# same test again but iterate over the
c.execute("SELECT * from dictcursor where name='bob'")
for r in c:
- self.assertEqual(bob, r, "fetch a 1 row result via iteration failed via DictCursor")
+ self.assertEqual(
+ bob, r, "fetch a 1 row result via iteration failed via DictCursor"
+ )
# get all 3 row via fetchall
c.execute("SELECT * from dictcursor")
r = c.fetchall()
- self.assertEqual([bob,jim,fred], r, "fetchall failed via DictCursor")
- #same test again but do a list comprehension
+ self.assertEqual([bob, jim, fred], r, "fetchall failed via DictCursor")
+ # same test again but do a list comprehension
c.execute("SELECT * from dictcursor")
r = list(c)
- self.assertEqual([bob,jim,fred], r, "DictCursor should be iterable")
+ self.assertEqual([bob, jim, fred], r, "DictCursor should be iterable")
# get all 2 row via fetchmany
c.execute("SELECT * from dictcursor")
r = c.fetchmany(2)
@@ -75,12 +83,13 @@ def test_DictCursor(self):
self._ensure_cursor_expired(c)
def test_custom_dict(self):
- class MyDict(dict): pass
+ class MyDict(dict):
+ pass
class MyDictCursor(self.cursor_type):
dict_type = MyDict
- keys = ['name', 'age', 'DOB']
+ keys = ["name", "age", "DOB"]
bob = MyDict([(k, self.bob[k]) for k in keys])
jim = MyDict([(k, self.jim[k]) for k in keys])
fred = MyDict([(k, self.fred[k]) for k in keys])
@@ -93,18 +102,15 @@ class MyDictCursor(self.cursor_type):
cur.execute("SELECT * FROM dictcursor")
r = cur.fetchall()
- self.assertEqual([bob, jim, fred], r,
- "fetchall failed via MyDictCursor")
+ self.assertEqual([bob, jim, fred], r, "fetchall failed via MyDictCursor")
cur.execute("SELECT * FROM dictcursor")
r = list(cur)
- self.assertEqual([bob, jim, fred], r,
- "list failed via MyDictCursor")
+ self.assertEqual([bob, jim, fred], r, "list failed via MyDictCursor")
cur.execute("SELECT * FROM dictcursor")
r = cur.fetchmany(2)
- self.assertEqual([bob, jim], r,
- "list failed via MyDictCursor")
+ self.assertEqual([bob, jim], r, "list failed via MyDictCursor")
self._ensure_cursor_expired(cur)
@@ -114,6 +120,8 @@ class TestSSDictCursor(TestDictCursor):
def _ensure_cursor_expired(self, cursor):
list(cursor.fetchall_unbuffered())
+
if __name__ == "__main__":
import unittest
+
unittest.main()
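The `dict_type` hook exercised by `test_custom_dict` above works the same way outside the test suite; a minimal sketch, with placeholder connection parameters:

```py
# Sketch only; host/user/database are placeholders.
import pymysql
import pymysql.cursors


class Row(dict):
    """Custom mapping type returned for each row."""


class RowCursor(pymysql.cursors.DictCursor):
    dict_type = Row


conn = pymysql.connect(host="localhost", user="root", database="test")
with conn.cursor(RowCursor) as cur:
    cur.execute("SELECT 1 AS a, 2 AS b")
    row = cur.fetchone()
    assert isinstance(row, Row) and row == {"a": 1, "b": 2}
```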
diff --git a/pymysql/tests/test_SSCursor.py b/pymysql/tests/test_SSCursor.py
index 3bbfcfa41..d5e6e2bce 100644
--- a/pymysql/tests/test_SSCursor.py
+++ b/pymysql/tests/test_SSCursor.py
@@ -1,15 +1,9 @@
-import sys
-
-try:
- from pymysql.tests import base
- import pymysql.cursors
- from pymysql.constants import CLIENT
-except Exception:
- # For local testing from top-level directory, without installing
- sys.path.append('../pymysql')
- from pymysql.tests import base
- import pymysql.cursors
- from pymysql.constants import CLIENT
+import pytest
+
+from pymysql.tests import base
+import pymysql.cursors
+from pymysql.constants import CLIENT, ER
+
class TestSSCursor(base.PyMySQLTestCase):
def test_SSCursor(self):
@@ -17,35 +11,35 @@ def test_SSCursor(self):
conn = self.connect(client_flag=CLIENT.MULTI_STATEMENTS)
data = [
- ('America', '', 'America/Jamaica'),
- ('America', '', 'America/Los_Angeles'),
- ('America', '', 'America/Lima'),
- ('America', '', 'America/New_York'),
- ('America', '', 'America/Menominee'),
- ('America', '', 'America/Havana'),
- ('America', '', 'America/El_Salvador'),
- ('America', '', 'America/Costa_Rica'),
- ('America', '', 'America/Denver'),
- ('America', '', 'America/Detroit'),]
+ ("America", "", "America/Jamaica"),
+ ("America", "", "America/Los_Angeles"),
+ ("America", "", "America/Lima"),
+ ("America", "", "America/New_York"),
+ ("America", "", "America/Menominee"),
+ ("America", "", "America/Havana"),
+ ("America", "", "America/El_Salvador"),
+ ("America", "", "America/Costa_Rica"),
+ ("America", "", "America/Denver"),
+ ("America", "", "America/Detroit"),
+ ]
cursor = conn.cursor(pymysql.cursors.SSCursor)
# Create table
- cursor.execute('CREATE TABLE tz_data ('
- 'region VARCHAR(64),'
- 'zone VARCHAR(64),'
- 'name VARCHAR(64))')
+ cursor.execute(
+ "CREATE TABLE tz_data (region VARCHAR(64), zone VARCHAR(64), name VARCHAR(64))"
+ )
conn.begin()
# Test INSERT
for i in data:
- cursor.execute('INSERT INTO tz_data VALUES (%s, %s, %s)', i)
- self.assertEqual(conn.affected_rows(), 1, 'affected_rows does not match')
+ cursor.execute("INSERT INTO tz_data VALUES (%s, %s, %s)", i)
+ self.assertEqual(conn.affected_rows(), 1, "affected_rows does not match")
conn.commit()
# Test fetchone()
iter = 0
- cursor.execute('SELECT * FROM tz_data')
+ cursor.execute("SELECT * FROM tz_data")
while True:
row = cursor.fetchone()
if row is None:
@@ -53,26 +47,35 @@ def test_SSCursor(self):
iter += 1
# Test cursor.rowcount
- self.assertEqual(cursor.rowcount, affected_rows,
- 'cursor.rowcount != %s' % (str(affected_rows)))
+ self.assertEqual(
+ cursor.rowcount,
+ affected_rows,
+ "cursor.rowcount != %s" % (str(affected_rows)),
+ )
# Test cursor.rownumber
- self.assertEqual(cursor.rownumber, iter,
- 'cursor.rowcount != %s' % (str(iter)))
+ self.assertEqual(
+ cursor.rownumber, iter, "cursor.rowcount != %s" % (str(iter))
+ )
# Test row came out the same as it went in
- self.assertEqual((row in data), True,
- 'Row not found in source data')
+ self.assertEqual((row in data), True, "Row not found in source data")
# Test fetchall
- cursor.execute('SELECT * FROM tz_data')
- self.assertEqual(len(cursor.fetchall()), len(data),
- 'fetchall failed. Number of rows does not match')
+ cursor.execute("SELECT * FROM tz_data")
+ self.assertEqual(
+ len(cursor.fetchall()),
+ len(data),
+ "fetchall failed. Number of rows does not match",
+ )
# Test fetchmany
- cursor.execute('SELECT * FROM tz_data')
- self.assertEqual(len(cursor.fetchmany(2)), 2,
- 'fetchmany failed. Number of rows does not match')
+ cursor.execute("SELECT * FROM tz_data")
+ self.assertEqual(
+ len(cursor.fetchmany(2)),
+ 2,
+ "fetchmany failed. Number of rows does not match",
+ )
# So MySQLdb won't throw "Commands out of sync"
while True:
@@ -81,30 +84,153 @@ def test_SSCursor(self):
break
# Test update, affected_rows()
- cursor.execute('UPDATE tz_data SET zone = %s', ['Foo'])
+ cursor.execute("UPDATE tz_data SET zone = %s", ["Foo"])
conn.commit()
- self.assertEqual(cursor.rowcount, len(data),
- 'Update failed. affected_rows != %s' % (str(len(data))))
+ self.assertEqual(
+ cursor.rowcount,
+ len(data),
+ "Update failed. affected_rows != %s" % (str(len(data))),
+ )
# Test executemany
- cursor.executemany('INSERT INTO tz_data VALUES (%s, %s, %s)', data)
- self.assertEqual(cursor.rowcount, len(data),
- 'executemany failed. cursor.rowcount != %s' % (str(len(data))))
+ cursor.executemany("INSERT INTO tz_data VALUES (%s, %s, %s)", data)
+ self.assertEqual(
+ cursor.rowcount,
+ len(data),
+ "executemany failed. cursor.rowcount != %s" % (str(len(data))),
+ )
# Test multiple datasets
- cursor.execute('SELECT 1; SELECT 2; SELECT 3')
- self.assertListEqual(list(cursor), [(1, )])
+ cursor.execute("SELECT 1; SELECT 2; SELECT 3")
+ self.assertListEqual(list(cursor), [(1,)])
self.assertTrue(cursor.nextset())
- self.assertListEqual(list(cursor), [(2, )])
+ self.assertListEqual(list(cursor), [(2,)])
self.assertTrue(cursor.nextset())
- self.assertListEqual(list(cursor), [(3, )])
+ self.assertListEqual(list(cursor), [(3,)])
self.assertFalse(cursor.nextset())
- cursor.execute('DROP TABLE IF EXISTS tz_data')
+ cursor.execute("DROP TABLE IF EXISTS tz_data")
cursor.close()
+ def test_execution_time_limit(self):
+ # this method is similarly implemented in test_cursor
+
+ conn = self.connect()
+
+ # table creation and filling is SSCursor only as it's not provided by self.setUp()
+ self.safe_create_table(
+ conn,
+ "test",
+ "create table test (data varchar(10))",
+ )
+ with conn.cursor() as cur:
+ cur.execute(
+ "insert into test (data) values "
+ "('row1'), ('row2'), ('row3'), ('row4'), ('row5')"
+ )
+ conn.commit()
+
+ db_type = self.get_mysql_vendor(conn)
+
+ with conn.cursor(pymysql.cursors.SSCursor) as cur:
+ # MySQL MAX_EXECUTION_TIME takes ms
+ # MariaDB max_statement_time takes seconds as int/float, introduced in 10.1
+
+ # this will sleep 0.01 seconds per row
+ if db_type == "mysql":
+ sql = (
+ "SELECT /*+ MAX_EXECUTION_TIME(2000) */ data, sleep(0.01) FROM test"
+ )
+ else:
+ sql = "SET STATEMENT max_statement_time=2 FOR SELECT data, sleep(0.01) FROM test"
+
+ cur.execute(sql)
+ # unlike Cursor, SSCursor returns a list of tuples here
+ self.assertEqual(
+ cur.fetchall(),
+ [
+ ("row1", 0),
+ ("row2", 0),
+ ("row3", 0),
+ ("row4", 0),
+ ("row5", 0),
+ ],
+ )
+
+ if db_type == "mysql":
+ sql = (
+ "SELECT /*+ MAX_EXECUTION_TIME(2000) */ data, sleep(0.01) FROM test"
+ )
+ else:
+ sql = "SET STATEMENT max_statement_time=2 FOR SELECT data, sleep(0.01) FROM test"
+ cur.execute(sql)
+ self.assertEqual(cur.fetchone(), ("row1", 0))
+
+ # this discards the previous unfinished query and raises an
+ # incomplete unbuffered query warning
+ with pytest.warns(UserWarning):
+ cur.execute("SELECT 1")
+ self.assertEqual(cur.fetchone(), (1,))
+
+ # SSCursor will not read the EOF packet until we try to read
+ # another row. Skipping this will raise an incomplete unbuffered
+ # query warning in the next cur.execute().
+ self.assertEqual(cur.fetchone(), None)
+
+ if db_type == "mysql":
+ sql = "SELECT /*+ MAX_EXECUTION_TIME(1) */ data, sleep(1) FROM test"
+ else:
+ sql = "SET STATEMENT max_statement_time=0.001 FOR SELECT data, sleep(1) FROM test"
+ with pytest.raises(pymysql.err.OperationalError) as cm:
+ # in an unbuffered cursor the OperationalError may not show up
+ # until fetching the entire result
+ cur.execute(sql)
+ cur.fetchall()
+
+ if db_type == "mysql":
+ # this constant was only introduced in MySQL 5.7, not sure
+ # what was returned before, may have been ER_QUERY_INTERRUPTED
+ self.assertEqual(cm.value.args[0], ER.QUERY_TIMEOUT)
+ else:
+ self.assertEqual(cm.value.args[0], ER.STATEMENT_TIMEOUT)
+
+ # connection should still be fine at this point
+ cur.execute("SELECT 1")
+ self.assertEqual(cur.fetchone(), (1,))
+
+ def test_warnings(self):
+ con = self.connect()
+ cur = con.cursor(pymysql.cursors.SSCursor)
+ cur.execute("DROP TABLE IF EXISTS `no_exists_table`")
+ self.assertEqual(cur.warning_count, 1)
+
+ cur.execute("SHOW WARNINGS")
+ w = cur.fetchone()
+ self.assertEqual(w[1], ER.BAD_TABLE_ERROR)
+ self.assertIn(
+ "no_exists_table",
+ w[2],
+ )
+
+ # ensure unbuffered result is finished
+ self.assertIsNone(cur.fetchone())
+
+ cur.execute("SELECT 1")
+ self.assertEqual(cur.fetchone(), (1,))
+ self.assertIsNone(cur.fetchone())
+
+ self.assertEqual(cur.warning_count, 0)
+
+ cur.execute("SELECT CAST('abc' AS SIGNED)")
+ # this ensures fully retrieving the unbuffered result
+ rows = cur.fetchmany(2)
+ self.assertEqual(len(rows), 1)
+ self.assertEqual(cur.warning_count, 1)
+
+
__all__ = ["TestSSCursor"]
if __name__ == "__main__":
import unittest
+
unittest.main()
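As the tests above illustrate, an `SSCursor` streams rows from the server and must be drained (including the terminating EOF packet) before the next `execute()`, or an "incomplete unbuffered query" warning is raised. A minimal usage sketch, with placeholder connection parameters:

```py
# Sketch only; host/user/database are placeholders.
import pymysql
import pymysql.cursors

conn = pymysql.connect(host="localhost", user="root", database="test")
with conn.cursor(pymysql.cursors.SSCursor) as cur:
    cur.execute("SELECT 1 UNION SELECT 2 UNION SELECT 3")
    rows = []
    while True:
        row = cur.fetchone()
        if row is None:  # EOF packet read; the unbuffered result is now complete
            break
        rows.append(row)
    assert sorted(rows) == [(1,), (2,), (3,)]
```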
diff --git a/pymysql/tests/test_basic.py b/pymysql/tests/test_basic.py
index a53373224..0fe13b59d 100644
--- a/pymysql/tests/test_basic.py
+++ b/pymysql/tests/test_basic.py
@@ -1,15 +1,11 @@
-# coding: utf-8
import datetime
import json
import time
-import warnings
-from unittest2 import SkipTest
+import pytest
-from pymysql import util
import pymysql.cursors
from pymysql.tests import base
-from pymysql.err import ProgrammingError
__all__ = ["TestConversion", "TestCursor", "TestBulkInserts"]
@@ -17,26 +13,66 @@
class TestConversion(base.PyMySQLTestCase):
def test_datatypes(self):
- """ test every data type """
- conn = self.connections[0]
+ """test every data type"""
+ conn = self.connect()
c = conn.cursor()
- c.execute("create table test_datatypes (b bit, i int, l bigint, f real, s varchar(32), u varchar(32), bb blob, d date, dt datetime, ts timestamp, td time, t time, st datetime)")
+ c.execute(
+ """
+create table test_datatypes (
+ b bit,
+ i int,
+ l bigint,
+ f real,
+ s varchar(32),
+ u varchar(32),
+ bb blob,
+ d date,
+ dt datetime,
+ ts timestamp,
+ td time,
+ t time,
+ st datetime)
+"""
+ )
try:
# insert values
- v = (True, -3, 123456789012, 5.7, "hello'\" world", u"Espa\xc3\xb1ol", "binary\x00data".encode(conn.encoding), datetime.date(1988,2,2), datetime.datetime(2014, 5, 15, 7, 45, 57), datetime.timedelta(5,6), datetime.time(16,32), time.localtime())
- c.execute("insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st) values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", v)
+ v = (
+ True,
+ -3,
+ 123456789012,
+ 5.7,
+ "hello'\" world",
+ "Espa\xc3\xb1ol",
+ "binary\x00data".encode(conn.encoding),
+ datetime.date(1988, 2, 2),
+ datetime.datetime(2014, 5, 15, 7, 45, 57),
+ datetime.timedelta(5, 6),
+ datetime.time(16, 32),
+ time.localtime(),
+ )
+ c.execute(
+ "insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st) values"
+ " (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)",
+ v,
+ )
c.execute("select b,i,l,f,s,u,bb,d,dt,td,t,st from test_datatypes")
r = c.fetchone()
- self.assertEqual(util.int2byte(1), r[0])
+ self.assertEqual(b"\x01", r[0])
self.assertEqual(v[1:10], r[1:10])
- self.assertEqual(datetime.timedelta(0, 60 * (v[10].hour * 60 + v[10].minute)), r[10])
+ self.assertEqual(
+ datetime.timedelta(0, 60 * (v[10].hour * 60 + v[10].minute)), r[10]
+ )
self.assertEqual(datetime.datetime(*v[-1][:6]), r[-1])
c.execute("delete from test_datatypes")
# check nulls
- c.execute("insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st) values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", [None] * 12)
+ c.execute(
+ "insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st)"
+ " values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)",
+ [None] * 12,
+ )
c.execute("select b,i,l,f,s,u,bb,d,dt,td,t,st from test_datatypes")
r = c.fetchone()
self.assertEqual(tuple([None] * 12), r)
@@ -45,30 +81,37 @@ def test_datatypes(self):
# check sequences type
for seq_type in (tuple, list, set, frozenset):
- c.execute("insert into test_datatypes (i, l) values (2,4), (6,8), (10,12)")
- seq = seq_type([2,6])
- c.execute("select l from test_datatypes where i in %s order by i", (seq,))
+ c.execute(
+ "insert into test_datatypes (i, l) values (2,4), (6,8), (10,12)"
+ )
+ seq = seq_type([2, 6])
+ c.execute(
+ "select l from test_datatypes where i in %s order by i", (seq,)
+ )
r = c.fetchall()
- self.assertEqual(((4,),(8,)), r)
+ self.assertEqual(((4,), (8,)), r)
c.execute("delete from test_datatypes")
finally:
c.execute("drop table test_datatypes")
def test_dict(self):
- """ test dict escaping """
- conn = self.connections[0]
+ """test dict escaping"""
+ conn = self.connect()
c = conn.cursor()
c.execute("create table test_dict (a integer, b integer, c integer)")
try:
- c.execute("insert into test_dict (a,b,c) values (%(a)s, %(b)s, %(c)s)", {"a":1,"b":2,"c":3})
+ c.execute(
+ "insert into test_dict (a,b,c) values (%(a)s, %(b)s, %(c)s)",
+ {"a": 1, "b": 2, "c": 3},
+ )
c.execute("select a,b,c from test_dict")
- self.assertEqual((1,2,3), c.fetchone())
+ self.assertEqual((1, 2, 3), c.fetchone())
finally:
c.execute("drop table test_dict")
def test_string(self):
- conn = self.connections[0]
+ conn = self.connect()
c = conn.cursor()
c.execute("create table test_dict (a text)")
test_value = "I am a test string"
@@ -80,7 +123,7 @@ def test_string(self):
c.execute("drop table test_dict")
def test_integer(self):
- conn = self.connections[0]
+ conn = self.connect()
c = conn.cursor()
c.execute("create table test_dict (a integer)")
test_value = 12345
@@ -94,9 +137,10 @@ def test_integer(self):
def test_binary(self):
"""test binary data"""
data = bytes(bytearray(range(255)))
- conn = self.connections[0]
+ conn = self.connect()
self.safe_create_table(
- conn, "test_binary", "create table test_binary (b binary(255))")
+ conn, "test_binary", "create table test_binary (b binary(255))"
+ )
with conn.cursor() as c:
c.execute("insert into test_binary (b) values (_binary %s)", (data,))
@@ -106,9 +150,8 @@ def test_binary(self):
def test_blob(self):
"""test blob data"""
data = bytes(bytearray(range(256)) * 4)
- conn = self.connections[0]
- self.safe_create_table(
- conn, "test_blob", "create table test_blob (b blob)")
+ conn = self.connect()
+ self.safe_create_table(conn, "test_blob", "create table test_blob (b blob)")
with conn.cursor() as c:
c.execute("insert into test_blob (b) values (_binary %s)", (data,))
@@ -116,42 +159,44 @@ def test_blob(self):
self.assertEqual(data, c.fetchone()[0])
def test_untyped(self):
- """ test conversion of null, empty string """
- conn = self.connections[0]
+ """test conversion of null, empty string"""
+ conn = self.connect()
c = conn.cursor()
c.execute("select null,''")
- self.assertEqual((None,u''), c.fetchone())
+ self.assertEqual((None, ""), c.fetchone())
c.execute("select '',null")
- self.assertEqual((u'',None), c.fetchone())
+ self.assertEqual(("", None), c.fetchone())
def test_timedelta(self):
- """ test timedelta conversion """
- conn = self.connections[0]
+ """test timedelta conversion"""
+ conn = self.connect()
c = conn.cursor()
- c.execute("select time('12:30'), time('23:12:59'), time('23:12:59.05100'), time('-12:30'), time('-23:12:59'), time('-23:12:59.05100'), time('-00:30')")
- self.assertEqual((datetime.timedelta(0, 45000),
- datetime.timedelta(0, 83579),
- datetime.timedelta(0, 83579, 51000),
- -datetime.timedelta(0, 45000),
- -datetime.timedelta(0, 83579),
- -datetime.timedelta(0, 83579, 51000),
- -datetime.timedelta(0, 1800)),
- c.fetchone())
+ c.execute(
+ "select time('12:30'), time('23:12:59'), time('23:12:59.05100'),"
+ + " time('-12:30'), time('-23:12:59'), time('-23:12:59.05100'), time('-00:30')"
+ )
+ self.assertEqual(
+ (
+ datetime.timedelta(0, 45000),
+ datetime.timedelta(0, 83579),
+ datetime.timedelta(0, 83579, 51000),
+ -datetime.timedelta(0, 45000),
+ -datetime.timedelta(0, 83579),
+ -datetime.timedelta(0, 83579, 51000),
+ -datetime.timedelta(0, 1800),
+ ),
+ c.fetchone(),
+ )
def test_datetime_microseconds(self):
- """ test datetime conversion w microseconds"""
+ """test datetime conversion w microseconds"""
- conn = self.connections[0]
- if not self.mysql_server_is(conn, (5, 6, 4)):
- raise SkipTest("target backend does not support microseconds")
+ conn = self.connect()
c = conn.cursor()
dt = datetime.datetime(2013, 11, 12, 9, 9, 9, 123450)
c.execute("create table test_datetime (id int, ts datetime(6))")
try:
- c.execute(
- "insert into test_datetime values (%s, %s)",
- (1, dt)
- )
+ c.execute("insert into test_datetime values (%s, %s)", (1, dt))
c.execute("select ts from test_datetime")
self.assertEqual((dt,), c.fetchone())
finally:
@@ -164,7 +209,7 @@ class TestCursor(base.PyMySQLTestCase):
# compatible with the DB-API 2.0 spec and has not broken
# any unit tests for anything we've tried.
- #def test_description(self):
+ # def test_description(self):
# """ test description attribute """
# # result is from MySQLdb module
# r = (('Host', 254, 11, 60, 60, 0, 0),
@@ -206,15 +251,15 @@ class TestCursor(base.PyMySQLTestCase):
# ('max_updates', 3, 1, 11, 11, 0, 0),
# ('max_connections', 3, 1, 11, 11, 0, 0),
# ('max_user_connections', 3, 1, 11, 11, 0, 0))
- # conn = self.connections[0]
+ # conn = self.connect()
# c = conn.cursor()
# c.execute("select * from mysql.user")
#
# self.assertEqual(r, c.description)
def test_fetch_no_result(self):
- """ test a fetchone() with no rows """
- conn = self.connections[0]
+ """test a fetchone() with no rows"""
+ conn = self.connect()
c = conn.cursor()
c.execute("create table test_nr (b varchar(32))")
try:
@@ -225,26 +270,26 @@ def test_fetch_no_result(self):
c.execute("drop table test_nr")
def test_aggregates(self):
- """ test aggregate functions """
- conn = self.connections[0]
+ """test aggregate functions"""
+ conn = self.connect()
c = conn.cursor()
try:
- c.execute('create table test_aggregates (i integer)')
+ c.execute("create table test_aggregates (i integer)")
for i in range(0, 10):
- c.execute('insert into test_aggregates (i) values (%s)', (i,))
- c.execute('select sum(i) from test_aggregates')
- r, = c.fetchone()
- self.assertEqual(sum(range(0,10)), r)
+ c.execute("insert into test_aggregates (i) values (%s)", (i,))
+ c.execute("select sum(i) from test_aggregates")
+ (r,) = c.fetchone()
+ self.assertEqual(sum(range(0, 10)), r)
finally:
- c.execute('drop table test_aggregates')
+ c.execute("drop table test_aggregates")
def test_single_tuple(self):
- """ test a single tuple """
- conn = self.connections[0]
+ """test a single tuple"""
+ conn = self.connect()
c = conn.cursor()
self.safe_create_table(
- conn, 'mystuff',
- "create table mystuff (id integer primary key)")
+ conn, "mystuff", "create table mystuff (id integer primary key)"
+ )
c.execute("insert into mystuff (id) values (1)")
c.execute("insert into mystuff (id) values (2)")
c.execute("select id from mystuff where id in %s", ((1,),))
@@ -255,82 +300,102 @@ def test_json(self):
args = self.databases[0].copy()
args["charset"] = "utf8mb4"
conn = pymysql.connect(**args)
+ # MariaDB only has limited JSON support, stores data as longtext
+ # https://mariadb.com/kb/en/json-data-type/
if not self.mysql_server_is(conn, (5, 7, 0)):
- raise SkipTest("JSON type is not supported on MySQL <= 5.6")
+ pytest.skip("JSON type is only supported on MySQL >= 5.7")
- self.safe_create_table(conn, "test_json", """\
+ self.safe_create_table(
+ conn,
+ "test_json",
+ """\
create table test_json (
id int not null,
json JSON not null,
primary key (id)
-);""")
+);""",
+ )
cur = conn.cursor()
-        json_str = u'{"hello": "こんにちは"}'
+        json_str = '{"hello": "こんにちは"}'
cur.execute("INSERT INTO test_json (id, `json`) values (42, %s)", (json_str,))
cur.execute("SELECT `json` from `test_json` WHERE `id`=42")
res = cur.fetchone()[0]
self.assertEqual(json.loads(res), json.loads(json_str))
- cur.execute("SELECT CAST(%s AS JSON) AS x", (json_str,))
- res = cur.fetchone()[0]
- self.assertEqual(json.loads(res), json.loads(json_str))
+ if self.get_mysql_vendor(conn) == "mysql":
+ cur.execute("SELECT CAST(%s AS JSON) AS x", (json_str,))
+ res = cur.fetchone()[0]
+ self.assertEqual(json.loads(res), json.loads(json_str))
class TestBulkInserts(base.PyMySQLTestCase):
-
cursor_type = pymysql.cursors.DictCursor
def setUp(self):
- super(TestBulkInserts, self).setUp()
- self.conn = conn = self.connections[0]
- c = conn.cursor(self.cursor_type)
+ super().setUp()
+ self.conn = conn = self.connect()
- # create a table ane some data to query
- self.safe_create_table(conn, 'bulkinsert', """\
+ # create a table and some data to query
+ self.safe_create_table(
+ conn,
+ "bulkinsert",
+ """\
CREATE TABLE bulkinsert
(
-id int(11),
+id int,
name char(20),
age int,
height int,
PRIMARY KEY (id)
)
-""")
+""",
+ )
def _verify_records(self, data):
- conn = self.connections[0]
+ conn = self.connect()
cursor = conn.cursor()
cursor.execute("SELECT id, name, age, height from bulkinsert")
result = cursor.fetchall()
self.assertEqual(sorted(data), sorted(result))
def test_bulk_insert(self):
- conn = self.connections[0]
+ conn = self.connect()
cursor = conn.cursor()
data = [(0, "bob", 21, 123), (1, "jim", 56, 45), (2, "fred", 100, 180)]
- cursor.executemany("insert into bulkinsert (id, name, age, height) "
- "values (%s,%s,%s,%s)", data)
+ cursor.executemany(
+ "insert into bulkinsert (id, name, age, height) values (%s,%s,%s,%s)",
+ data,
+ )
self.assertEqual(
- cursor._last_executed, bytearray(
- b"insert into bulkinsert (id, name, age, height) values "
- b"(0,'bob',21,123),(1,'jim',56,45),(2,'fred',100,180)"))
- cursor.execute('commit')
+ cursor._executed,
+ bytearray(
+ b"insert into bulkinsert (id, name, age, height) values "
+ b"(0,'bob',21,123),(1,'jim',56,45),(2,'fred',100,180)"
+ ),
+ )
+ cursor.execute("commit")
self._verify_records(data)
def test_bulk_insert_multiline_statement(self):
- conn = self.connections[0]
+ conn = self.connect()
cursor = conn.cursor()
data = [(0, "bob", 21, 123), (1, "jim", 56, 45), (2, "fred", 100, 180)]
- cursor.executemany("""insert
+ cursor.executemany(
+ """insert
into bulkinsert (id, name,
age, height)
values (%s,
%s , %s,
%s )
- """, data)
- self.assertEqual(cursor._last_executed.strip(), bytearray(b"""insert
+ """,
+ data,
+ )
+ self.assertEqual(
+ cursor._executed.strip(),
+ bytearray(
+ b"""insert
into bulkinsert (id, name,
age, height)
values (0,
@@ -339,33 +404,43 @@ def test_bulk_insert_multiline_statement(self):
'jim' , 56,
45 ),(2,
'fred' , 100,
-180 )"""))
- cursor.execute('commit')
+180 )"""
+ ),
+ )
+ cursor.execute("commit")
self._verify_records(data)
def test_bulk_insert_single_record(self):
- conn = self.connections[0]
+ conn = self.connect()
cursor = conn.cursor()
data = [(0, "bob", 21, 123)]
- cursor.executemany("insert into bulkinsert (id, name, age, height) "
- "values (%s,%s,%s,%s)", data)
- cursor.execute('commit')
+ cursor.executemany(
+ "insert into bulkinsert (id, name, age, height) values (%s,%s,%s,%s)",
+ data,
+ )
+ cursor.execute("commit")
self._verify_records(data)
def test_issue_288(self):
- """executemany should work with "insert ... on update" """
- conn = self.connections[0]
+ """executemany should work with "insert ... on update"""
+ conn = self.connect()
cursor = conn.cursor()
data = [(0, "bob", 21, 123), (1, "jim", 56, 45), (2, "fred", 100, 180)]
- cursor.executemany("""insert
+ cursor.executemany(
+ """insert
into bulkinsert (id, name,
age, height)
values (%s,
%s , %s,
%s ) on duplicate key update
age = values(age)
- """, data)
- self.assertEqual(cursor._last_executed.strip(), bytearray(b"""insert
+ """,
+ data,
+ )
+ self.assertEqual(
+ cursor._executed.strip(),
+ bytearray(
+ b"""insert
into bulkinsert (id, name,
age, height)
values (0,
@@ -375,17 +450,8 @@ def test_issue_288(self):
45 ),(2,
'fred' , 100,
180 ) on duplicate key update
-age = values(age)"""))
- cursor.execute('commit')
+age = values(age)"""
+ ),
+ )
+ cursor.execute("commit")
self._verify_records(data)
-
- def test_warnings(self):
- con = self.connections[0]
- cur = con.cursor()
- with warnings.catch_warnings(record=True) as ws:
- warnings.simplefilter("always")
- cur.execute("drop table if exists no_exists_table")
- self.assertEqual(len(ws), 1)
- self.assertEqual(ws[0].category, pymysql.Warning)
- if u"no_exists_table" not in str(ws[0].message):
- self.fail("'no_exists_table' not in %s" % (str(ws[0].message),))
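The bulk-insert tests above rely on `executemany()` rewriting a simple `INSERT ... VALUES (%s, ...)` into a single multi-row statement, now inspected via `cursor._executed`. A minimal sketch of the same pattern, with placeholder connection parameters:

```py
# Sketch only; host/user/database are placeholders.
import pymysql

conn = pymysql.connect(host="localhost", user="root", database="test")
data = [(0, "bob", 21, 123), (1, "jim", 56, 45), (2, "fred", 100, 180)]
with conn.cursor() as cur:
    cur.execute(
        "CREATE TABLE IF NOT EXISTS bulkinsert "
        "(id int PRIMARY KEY, name char(20), age int, height int)"
    )
    # Expanded client-side into one multi-row INSERT statement.
    cur.executemany(
        "insert into bulkinsert (id, name, age, height) values (%s,%s,%s,%s)", data
    )
conn.commit()
```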
diff --git a/pymysql/tests/test_charset.py b/pymysql/tests/test_charset.py
new file mode 100644
index 000000000..94e6e1559
--- /dev/null
+++ b/pymysql/tests/test_charset.py
@@ -0,0 +1,25 @@
+import pymysql.charset
+
+
+def test_utf8():
+ utf8mb3 = pymysql.charset.charset_by_name("utf8mb3")
+ assert utf8mb3.name == "utf8mb3"
+ assert utf8mb3.collation == "utf8mb3_general_ci"
+ assert (
+ repr(utf8mb3)
+ == "Charset(id=33, name='utf8mb3', collation='utf8mb3_general_ci')"
+ )
+
+ # MySQL 8.0 changed the default collation for utf8mb4.
+ # But we use old default for compatibility.
+ utf8mb4 = pymysql.charset.charset_by_name("utf8mb4")
+ assert utf8mb4.name == "utf8mb4"
+ assert utf8mb4.collation == "utf8mb4_general_ci"
+ assert (
+ repr(utf8mb4)
+ == "Charset(id=45, name='utf8mb4', collation='utf8mb4_general_ci')"
+ )
+
+ # utf8 is alias of utf8mb4 since MySQL 8.0, and PyMySQL v1.1.
+ utf8 = pymysql.charset.charset_by_name("utf8")
+ assert utf8 == utf8mb4
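A minimal sketch of the `pymysql.charset` lookups exercised by `test_utf8` above:

```py
# Sketch only.
import pymysql.charset

utf8mb4 = pymysql.charset.charset_by_name("utf8mb4")
print(utf8mb4.name, utf8mb4.collation)  # utf8mb4 utf8mb4_general_ci (pre-8.0 default kept)
assert pymysql.charset.charset_by_name("utf8") == utf8mb4  # "utf8" is an alias of utf8mb4
```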
diff --git a/pymysql/tests/test_connection.py b/pymysql/tests/test_connection.py
index 5e95b1c8c..dcf3394c1 100644
--- a/pymysql/tests/test_connection.py
+++ b/pymysql/tests/test_connection.py
@@ -1,10 +1,11 @@
import datetime
-import sys
+import ssl
+import pytest
import time
-import unittest2
+from unittest import mock
+
import pymysql
from pymysql.tests import base
-from pymysql._compat import text_type
from pymysql.constants import CLIENT
@@ -27,7 +28,7 @@ def __init__(self, c, user, db, auth=None, authdata=None, password=None):
# already exists - TODO need to check the same plugin applies
self._created = False
try:
- c.execute("GRANT SELECT ON %s.* TO %s" % (db, user))
+ c.execute(f"GRANT SELECT ON {db}.* TO {user}")
self._grant = True
except pymysql.err.InternalError:
self._grant = False
@@ -37,13 +38,12 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, traceback):
if self._grant:
- self._c.execute("REVOKE SELECT ON %s.* FROM %s" % (self._db, self._user))
+ self._c.execute(f"REVOKE SELECT ON {self._db}.* FROM {self._user}")
if self._created:
self._c.execute("DROP USER %s" % self._user)
class TestAuthentication(base.PyMySQLTestCase):
-
socket_auth = False
socket_found = False
two_questions_found = False
@@ -51,36 +51,40 @@ class TestAuthentication(base.PyMySQLTestCase):
pam_found = False
mysql_old_password_found = False
sha256_password_found = False
+ ed25519_found = False
import os
- osuser = os.environ.get('USER')
+
+ osuser = os.environ.get("USER")
# socket auth requires the current user and for the connection to be a socket
# rest do grants @localhost due to incomplete logic - TODO change to @% then
db = base.PyMySQLTestCase.databases[0].copy()
- socket_auth = db.get('unix_socket') is not None \
- and db.get('host') in ('localhost', '127.0.0.1')
+ socket_auth = db.get("unix_socket") is not None and db.get("host") in (
+ "localhost",
+ "127.0.0.1",
+ )
cur = pymysql.connect(**db).cursor()
- del db['user']
+ del db["user"]
cur.execute("SHOW PLUGINS")
for r in cur:
- if (r[1], r[2]) != (u'ACTIVE', u'AUTHENTICATION'):
+ if (r[1], r[2]) != ("ACTIVE", "AUTHENTICATION"):
continue
- if r[3] == u'auth_socket.so':
+ if r[3] == "auth_socket.so" or r[0] == "unix_socket":
socket_plugin_name = r[0]
socket_found = True
- elif r[3] == u'dialog_examples.so':
- if r[0] == 'two_questions':
- two_questions_found = True
- elif r[0] == 'three_attempts':
- three_attempts_found = True
- elif r[0] == u'pam':
+ elif r[3] == "dialog_examples.so":
+ if r[0] == "two_questions":
+ two_questions_found = True
+ elif r[0] == "three_attempts":
+ three_attempts_found = True
+ elif r[0] == "pam":
pam_found = True
- pam_plugin_name = r[3].split('.')[0]
- if pam_plugin_name == 'auth_pam':
- pam_plugin_name = 'pam'
+ pam_plugin_name = r[3].split(".")[0]
+ if pam_plugin_name == "auth_pam":
+ pam_plugin_name = "pam"
# MySQL: authentication_pam
# https://dev.mysql.com/doc/refman/5.5/en/pam-authentication-plugin.html
@@ -88,306 +92,385 @@ class TestAuthentication(base.PyMySQLTestCase):
# https://mariadb.com/kb/en/mariadb/pam-authentication-plugin/
# Names differ but functionality is close
- elif r[0] == u'mysql_old_password':
+ elif r[0] == "mysql_old_password":
mysql_old_password_found = True
- elif r[0] == u'sha256_password':
+ elif r[0] == "sha256_password":
sha256_password_found = True
- #else:
+ elif r[0] == "ed25519":
+ ed25519_found = True
+ # else:
# print("plugin: %r" % r[0])
def test_plugin(self):
- if not self.mysql_server_is(self.connections[0], (5, 5, 0)):
- raise unittest2.SkipTest("MySQL-5.5 required for plugins")
- cur = self.connections[0].cursor()
- cur.execute("select plugin from mysql.user where concat(user, '@', host)=current_user()")
+ conn = self.connect()
+ cur = conn.cursor()
+ cur.execute(
+ "select plugin from mysql.user where concat(user, '@', host)=current_user()"
+ )
for r in cur:
- self.assertIn(self.connections[0]._auth_plugin_name, (r[0], 'mysql_native_password'))
+ self.assertIn(conn._auth_plugin_name, (r[0], "mysql_native_password"))
- @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
- @unittest2.skipIf(socket_found, "socket plugin already installed")
+ @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required")
+ @pytest.mark.skipif(socket_found, reason="socket plugin already installed")
def testSocketAuthInstallPlugin(self):
# needs plugin. lets install it.
- cur = self.connections[0].cursor()
+ cur = self.connect().cursor()
try:
cur.execute("install plugin auth_socket soname 'auth_socket.so'")
TestAuthentication.socket_found = True
- self.socket_plugin_name = 'auth_socket'
+ self.socket_plugin_name = "auth_socket"
self.realtestSocketAuth()
except pymysql.err.InternalError:
try:
cur.execute("install soname 'auth_socket'")
TestAuthentication.socket_found = True
- self.socket_plugin_name = 'unix_socket'
+ self.socket_plugin_name = "unix_socket"
self.realtestSocketAuth()
except pymysql.err.InternalError:
TestAuthentication.socket_found = False
- raise unittest2.SkipTest('we couldn\'t install the socket plugin')
+ pytest.skip("we couldn't install the socket plugin")
finally:
if TestAuthentication.socket_found:
cur.execute("uninstall plugin %s" % self.socket_plugin_name)
- @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
- @unittest2.skipUnless(socket_found, "no socket plugin")
+ @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required")
+ @pytest.mark.skipif(not socket_found, reason="no socket plugin")
def testSocketAuth(self):
self.realtestSocketAuth()
def realtestSocketAuth(self):
- with TempUser(self.connections[0].cursor(), TestAuthentication.osuser + '@localhost',
- self.databases[0]['db'], self.socket_plugin_name) as u:
- c = pymysql.connect(user=TestAuthentication.osuser, **self.db)
+ with TempUser(
+ self.connect().cursor(),
+ TestAuthentication.osuser + "@localhost",
+ self.databases[0]["database"],
+ self.socket_plugin_name,
+ ):
+ pymysql.connect(user=TestAuthentication.osuser, **self.db)
- class Dialog(object):
- fail=False
+ class Dialog:
+ fail = False
def __init__(self, con):
- self.fail=TestAuthentication.Dialog.fail
+ self.fail = TestAuthentication.Dialog.fail
pass
def prompt(self, echo, prompt):
if self.fail:
- self.fail=False
- return b'bad guess at a password'
+ self.fail = False
+ return b"bad guess at a password"
return self.m.get(prompt)
- class DialogHandler(object):
-
+ class DialogHandler:
def __init__(self, con):
- self.con=con
+ self.con = con
def authenticate(self, pkt):
while True:
flag = pkt.read_uint8()
- echo = (flag & 0x06) == 0x02
+ # echo = (flag & 0x06) == 0x02
last = (flag & 0x01) == 0x01
prompt = pkt.read_all()
- if prompt == b'Password, please:':
- self.con.write_packet(b'stillnotverysecret\0')
+ if prompt == b"Password, please:":
+ self.con.write_packet(b"stillnotverysecret\0")
else:
- self.con.write_packet(b'no idea what to do with this prompt\0')
+ self.con.write_packet(b"no idea what to do with this prompt\0")
pkt = self.con._read_packet()
pkt.check_error()
if pkt.is_ok_packet() or last:
break
return pkt
- class DefectiveHandler(object):
+ class DefectiveHandler:
def __init__(self, con):
- self.con=con
-
+ self.con = con
- @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
- @unittest2.skipIf(two_questions_found, "two_questions plugin already installed")
+ @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required")
+ @pytest.mark.skipif(
+ two_questions_found, reason="two_questions plugin already installed"
+ )
def testDialogAuthTwoQuestionsInstallPlugin(self):
# needs plugin. lets install it.
- cur = self.connections[0].cursor()
+ cur = self.connect().cursor()
try:
cur.execute("install plugin two_questions soname 'dialog_examples.so'")
TestAuthentication.two_questions_found = True
self.realTestDialogAuthTwoQuestions()
except pymysql.err.InternalError:
- raise unittest2.SkipTest('we couldn\'t install the two_questions plugin')
+ pytest.skip("we couldn't install the two_questions plugin")
finally:
if TestAuthentication.two_questions_found:
cur.execute("uninstall plugin two_questions")
- @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
- @unittest2.skipUnless(two_questions_found, "no two questions auth plugin")
+ @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required")
+ @pytest.mark.skipif(not two_questions_found, reason="no two questions auth plugin")
def testDialogAuthTwoQuestions(self):
self.realTestDialogAuthTwoQuestions()
def realTestDialogAuthTwoQuestions(self):
- TestAuthentication.Dialog.fail=False
- TestAuthentication.Dialog.m = {b'Password, please:': b'notverysecret',
- b'Are you sure ?': b'yes, of course'}
- with TempUser(self.connections[0].cursor(), 'pymysql_2q@localhost',
- self.databases[0]['db'], 'two_questions', 'notverysecret') as u:
+ TestAuthentication.Dialog.fail = False
+ TestAuthentication.Dialog.m = {
+ b"Password, please:": b"notverysecret",
+ b"Are you sure ?": b"yes, of course",
+ }
+ with TempUser(
+ self.connect().cursor(),
+ "pymysql_2q@localhost",
+ self.databases[0]["database"],
+ "two_questions",
+ "notverysecret",
+ ):
with self.assertRaises(pymysql.err.OperationalError):
- pymysql.connect(user='pymysql_2q', **self.db)
- pymysql.connect(user='pymysql_2q', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
-
- @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
- @unittest2.skipIf(three_attempts_found, "three_attempts plugin already installed")
+ pymysql.connect(user="pymysql_2q", **self.db)
+ pymysql.connect(
+ user="pymysql_2q",
+ auth_plugin_map={b"dialog": TestAuthentication.Dialog},
+ **self.db,
+ )
+
+ @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required")
+ @pytest.mark.skipif(
+ three_attempts_found, reason="three_attempts plugin already installed"
+ )
def testDialogAuthThreeAttemptsQuestionsInstallPlugin(self):
# needs plugin. lets install it.
- cur = self.connections[0].cursor()
+ cur = self.connect().cursor()
try:
cur.execute("install plugin three_attempts soname 'dialog_examples.so'")
TestAuthentication.three_attempts_found = True
self.realTestDialogAuthThreeAttempts()
except pymysql.err.InternalError:
- raise unittest2.SkipTest('we couldn\'t install the three_attempts plugin')
+ pytest.skip("we couldn't install the three_attempts plugin")
finally:
if TestAuthentication.three_attempts_found:
cur.execute("uninstall plugin three_attempts")
- @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
- @unittest2.skipUnless(three_attempts_found, "no three attempts plugin")
+ @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required")
+ @pytest.mark.skipif(not three_attempts_found, reason="no three attempts plugin")
def testDialogAuthThreeAttempts(self):
self.realTestDialogAuthThreeAttempts()
def realTestDialogAuthThreeAttempts(self):
- TestAuthentication.Dialog.m = {b'Password, please:': b'stillnotverysecret'}
- TestAuthentication.Dialog.fail=True # fail just once. We've got three attempts after all
- with TempUser(self.connections[0].cursor(), 'pymysql_3a@localhost',
- self.databases[0]['db'], 'three_attempts', 'stillnotverysecret') as u:
- pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
- pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.DialogHandler}, **self.db)
+ TestAuthentication.Dialog.m = {b"Password, please:": b"stillnotverysecret"}
+ TestAuthentication.Dialog.fail = (
+ True # fail just once. We've got three attempts after all
+ )
+ with TempUser(
+ self.connect().cursor(),
+ "pymysql_3a@localhost",
+ self.databases[0]["database"],
+ "three_attempts",
+ "stillnotverysecret",
+ ):
+ pymysql.connect(
+ user="pymysql_3a",
+ auth_plugin_map={b"dialog": TestAuthentication.Dialog},
+ **self.db,
+ )
+ pymysql.connect(
+ user="pymysql_3a",
+ auth_plugin_map={b"dialog": TestAuthentication.DialogHandler},
+ **self.db,
+ )
with self.assertRaises(pymysql.err.OperationalError):
- pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': object}, **self.db)
+ pymysql.connect(
+ user="pymysql_3a", auth_plugin_map={b"dialog": object}, **self.db
+ )
with self.assertRaises(pymysql.err.OperationalError):
- pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.DefectiveHandler}, **self.db)
+ pymysql.connect(
+ user="pymysql_3a",
+ auth_plugin_map={b"dialog": TestAuthentication.DefectiveHandler},
+ **self.db,
+ )
with self.assertRaises(pymysql.err.OperationalError):
- pymysql.connect(user='pymysql_3a', auth_plugin_map={b'notdialogplugin': TestAuthentication.Dialog}, **self.db)
- TestAuthentication.Dialog.m = {b'Password, please:': b'I do not know'}
+ pymysql.connect(
+ user="pymysql_3a",
+ auth_plugin_map={b"notdialogplugin": TestAuthentication.Dialog},
+ **self.db,
+ )
+ TestAuthentication.Dialog.m = {b"Password, please:": b"I do not know"}
with self.assertRaises(pymysql.err.OperationalError):
- pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
- TestAuthentication.Dialog.m = {b'Password, please:': None}
+ pymysql.connect(
+ user="pymysql_3a",
+ auth_plugin_map={b"dialog": TestAuthentication.Dialog},
+ **self.db,
+ )
+ TestAuthentication.Dialog.m = {b"Password, please:": None}
with self.assertRaises(pymysql.err.OperationalError):
- pymysql.connect(user='pymysql_3a', auth_plugin_map={b'dialog': TestAuthentication.Dialog}, **self.db)
-
- @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
- @unittest2.skipIf(pam_found, "pam plugin already installed")
- @unittest2.skipIf(os.environ.get('PASSWORD') is None, "PASSWORD env var required")
- @unittest2.skipIf(os.environ.get('PAMSERVICE') is None, "PAMSERVICE env var required")
+ pymysql.connect(
+ user="pymysql_3a",
+ auth_plugin_map={b"dialog": TestAuthentication.Dialog},
+ **self.db,
+ )
+
+ @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required")
+ @pytest.mark.skipif(pam_found, reason="pam plugin already installed")
+ @pytest.mark.skipif(
+ os.environ.get("PASSWORD") is None, reason="PASSWORD env var required"
+ )
+ @pytest.mark.skipif(
+ os.environ.get("PAMSERVICE") is None, reason="PAMSERVICE env var required"
+ )
def testPamAuthInstallPlugin(self):
# needs plugin. lets install it.
- cur = self.connections[0].cursor()
+ cur = self.connect().cursor()
try:
cur.execute("install plugin pam soname 'auth_pam.so'")
TestAuthentication.pam_found = True
self.realTestPamAuth()
except pymysql.err.InternalError:
- raise unittest2.SkipTest('we couldn\'t install the auth_pam plugin')
+ pytest.skip("we couldn't install the auth_pam plugin")
finally:
if TestAuthentication.pam_found:
cur.execute("uninstall plugin pam")
-
- @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
- @unittest2.skipUnless(pam_found, "no pam plugin")
- @unittest2.skipIf(os.environ.get('PASSWORD') is None, "PASSWORD env var required")
- @unittest2.skipIf(os.environ.get('PAMSERVICE') is None, "PAMSERVICE env var required")
+ @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required")
+ @pytest.mark.skipif(not pam_found, reason="no pam plugin")
+ @pytest.mark.skipif(
+ os.environ.get("PASSWORD") is None, reason="PASSWORD env var required"
+ )
+ @pytest.mark.skipif(
+ os.environ.get("PAMSERVICE") is None, reason="PAMSERVICE env var required"
+ )
def testPamAuth(self):
self.realTestPamAuth()
def realTestPamAuth(self):
db = self.db.copy()
import os
- db['password'] = os.environ.get('PASSWORD')
- cur = self.connections[0].cursor()
+
+ db["password"] = os.environ.get("PASSWORD")
+ cur = self.connect().cursor()
try:
- cur.execute('show grants for ' + TestAuthentication.osuser + '@localhost')
+ cur.execute("show grants for " + TestAuthentication.osuser + "@localhost")
grants = cur.fetchone()[0]
- cur.execute('drop user ' + TestAuthentication.osuser + '@localhost')
+ cur.execute("drop user " + TestAuthentication.osuser + "@localhost")
except pymysql.OperationalError as e:
# assuming the user doesn't exist which is ok too
self.assertEqual(1045, e.args[0])
grants = None
- with TempUser(cur, TestAuthentication.osuser + '@localhost',
- self.databases[0]['db'], 'pam', os.environ.get('PAMSERVICE')) as u:
+ with TempUser(
+ cur,
+ TestAuthentication.osuser + "@localhost",
+ self.databases[0]["database"],
+ "pam",
+ os.environ.get("PAMSERVICE"),
+ ):
try:
- c = pymysql.connect(user=TestAuthentication.osuser, **db)
- db['password'] = 'very bad guess at password'
+ pymysql.connect(user=TestAuthentication.osuser, **db)
+ db["password"] = "very bad guess at password"
with self.assertRaises(pymysql.err.OperationalError):
- pymysql.connect(user=TestAuthentication.osuser,
- auth_plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler},
- **self.db)
+ pymysql.connect(
+ user=TestAuthentication.osuser,
+ auth_plugin_map={
+ b"mysql_cleartext_password": TestAuthentication.DefectiveHandler
+ },
+ **self.db,
+ )
except pymysql.OperationalError as e:
self.assertEqual(1045, e.args[0])
- # we had 'bad guess at password' work with pam. Well at least we get a permission denied here
+ # we had 'bad guess at password' work with pam. Well at least we get
+ # a permission denied here
with self.assertRaises(pymysql.err.OperationalError):
- pymysql.connect(user=TestAuthentication.osuser,
- auth_plugin_map={b'mysql_cleartext_password': TestAuthentication.DefectiveHandler},
- **self.db)
+ pymysql.connect(
+ user=TestAuthentication.osuser,
+ auth_plugin_map={
+ b"mysql_cleartext_password": TestAuthentication.DefectiveHandler
+ },
+ **self.db,
+ )
if grants:
# recreate the user
cur.execute(grants)
- # select old_password("crummy p\tassword");
- #| old_password("crummy p\tassword") |
- #| 2a01785203b08770 |
- @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
- @unittest2.skipUnless(mysql_old_password_found, "no mysql_old_password plugin")
- def testMySQLOldPasswordAuth(self):
- if self.mysql_server_is(self.connections[0], (5, 7, 0)):
- raise unittest2.SkipTest('Old passwords aren\'t supported in 5.7')
- # pymysql.err.OperationalError: (1045, "Access denied for user 'old_pass_user'@'localhost' (using password: YES)")
- # from login in MySQL-5.6
- if self.mysql_server_is(self.connections[0], (5, 6, 0)):
- raise unittest2.SkipTest('Old passwords don\'t authenticate in 5.6')
- db = self.db.copy()
- db['password'] = "crummy p\tassword"
- with self.connections[0] as c:
- # deprecated in 5.6
- if sys.version_info[0:2] >= (3,2) and self.mysql_server_is(self.connections[0], (5, 6, 0)):
- with self.assertWarns(pymysql.err.Warning) as cm:
- c.execute("SELECT OLD_PASSWORD('%s')" % db['password'])
- else:
- c.execute("SELECT OLD_PASSWORD('%s')" % db['password'])
- v = c.fetchone()[0]
- self.assertEqual(v, '2a01785203b08770')
- # only works in MariaDB and MySQL-5.6 - can't separate out by version
- #if self.mysql_server_is(self.connections[0], (5, 5, 0)):
- # with TempUser(c, 'old_pass_user@localhost',
- # self.databases[0]['db'], 'mysql_old_password', '2a01785203b08770') as u:
- # cur = pymysql.connect(user='old_pass_user', **db).cursor()
- # cur.execute("SELECT VERSION()")
- c.execute("SELECT @@secure_auth")
- secure_auth_setting = c.fetchone()[0]
- c.execute('set old_passwords=1')
- # pymysql.err.Warning: 'pre-4.1 password hash' is deprecated and will be removed in a future release. Please use post-4.1 password hash instead
- if sys.version_info[0:2] >= (3,2) and self.mysql_server_is(self.connections[0], (5, 6, 0)):
- with self.assertWarns(pymysql.err.Warning) as cm:
- c.execute('set global secure_auth=0')
- else:
- c.execute('set global secure_auth=0')
- with TempUser(c, 'old_pass_user@localhost',
- self.databases[0]['db'], password=db['password']) as u:
- cur = pymysql.connect(user='old_pass_user', **db).cursor()
- cur.execute("SELECT VERSION()")
- c.execute('set global secure_auth=%r' % secure_auth_setting)
-
- @unittest2.skipUnless(socket_auth, "connection to unix_socket required")
- @unittest2.skipUnless(sha256_password_found, "no sha256 password authentication plugin found")
+ @pytest.mark.skipif(not socket_auth, reason="connection to unix_socket required")
+ @pytest.mark.skipif(
+ not sha256_password_found,
+ reason="no sha256 password authentication plugin found",
+ )
def testAuthSHA256(self):
- c = self.connections[0].cursor()
- with TempUser(c, 'pymysql_sha256@localhost',
- self.databases[0]['db'], 'sha256_password') as u:
- if self.mysql_server_is(self.connections[0], (5, 7, 0)):
- c.execute("SET PASSWORD FOR 'pymysql_sha256'@'localhost' ='Sh@256Pa33'")
- else:
- c.execute('SET old_passwords = 2')
- c.execute("SET PASSWORD FOR 'pymysql_sha256'@'localhost' = PASSWORD('Sh@256Pa33')")
+ conn = self.connect()
+ c = conn.cursor()
+ with TempUser(
+ c,
+ "pymysql_sha256@localhost",
+ self.databases[0]["database"],
+ "sha256_password",
+ ):
+ c.execute("SET PASSWORD FOR 'pymysql_sha256'@'localhost' ='Sh@256Pa33'")
c.execute("FLUSH PRIVILEGES")
db = self.db.copy()
- db['password'] = "Sh@256Pa33"
- # Although SHA256 is supported, need the configuration of public key of the mysql server. Currently will get error by this test.
+ db["password"] = "Sh@256Pa33"
+ # Although SHA256 is supported, need the configuration of public key of
+ # the mysql server. Currently will get error by this test.
with self.assertRaises(pymysql.err.OperationalError):
- pymysql.connect(user='pymysql_sha256', **db)
+ pymysql.connect(user="pymysql_sha256", **db)
+
+    @pytest.mark.skipif(not ed25519_found, reason="no ed25519 authentication plugin")
+ def testAuthEd25519(self):
+ db = self.db.copy()
+ del db["password"]
+ conn = self.connect()
+ c = conn.cursor()
+ c.execute("select ed25519_password(''), ed25519_password('ed25519_password')")
+ for r in c:
+ empty_pass = r[0].decode("ascii")
+ non_empty_pass = r[1].decode("ascii")
+
+ with TempUser(
+ c,
+ "pymysql_ed25519",
+ self.databases[0]["database"],
+ "ed25519",
+ empty_pass,
+ ):
+ pymysql.connect(user="pymysql_ed25519", password="", **db)
+
+ with TempUser(
+ c,
+ "pymysql_ed25519",
+ self.databases[0]["database"],
+ "ed25519",
+ non_empty_pass,
+ ):
+ pymysql.connect(user="pymysql_ed25519", password="ed25519_password", **db)
-class TestConnection(base.PyMySQLTestCase):
+class TestConnection(base.PyMySQLTestCase):
def test_utf8mb4(self):
"""This test requires MySQL >= 5.5"""
arg = self.databases[0].copy()
- arg['charset'] = 'utf8mb4'
- conn = pymysql.connect(**arg)
+ arg["charset"] = "utf8mb4"
+ pymysql.connect(**arg)
+
+ def test_set_character_set(self):
+ con = self.connect()
+ cur = con.cursor()
+
+ con.set_character_set("latin1")
+ cur.execute("SELECT @@character_set_connection")
+ self.assertEqual(cur.fetchone(), ("latin1",))
+ self.assertEqual(con.encoding, "cp1252")
+
+ con.set_character_set("utf8mb4", "utf8mb4_general_ci")
+ cur.execute("SELECT @@character_set_connection, @@collation_connection")
+ self.assertEqual(cur.fetchone(), ("utf8mb4", "utf8mb4_general_ci"))
+ self.assertEqual(con.encoding, "utf8")
def test_largedata(self):
"""Large query and response (>=16MB)"""
- cur = self.connections[0].cursor()
+ cur = self.connect().cursor()
cur.execute("SELECT @@max_allowed_packet")
- if cur.fetchone()[0] < 16*1024*1024 + 10:
+ if cur.fetchone()[0] < 16 * 1024 * 1024 + 10:
print("Set max_allowed_packet to bigger than 17MB")
return
- t = 'a' * (16*1024*1024)
+ t = "a" * (16 * 1024 * 1024)
cur.execute("SELECT '" + t + "'")
assert cur.fetchone()[0] == t
def test_autocommit(self):
- con = self.connections[0]
+ con = self.connect()
self.assertFalse(con.get_autocommit())
cur = con.cursor()
@@ -400,16 +483,16 @@ def test_autocommit(self):
self.assertEqual(cur.fetchone()[0], 0)
def test_select_db(self):
- con = self.connections[0]
- current_db = self.databases[0]['db']
- other_db = self.databases[1]['db']
+ con = self.connect()
+ current_db = self.databases[0]["database"]
+ other_db = self.databases[1]["database"]
cur = con.cursor()
- cur.execute('SELECT database()')
+ cur.execute("SELECT database()")
self.assertEqual(cur.fetchone()[0], current_db)
con.select_db(other_db)
- cur.execute('SELECT database()')
+ cur.execute("SELECT database()")
self.assertEqual(cur.fetchone()[0], other_db)
def test_connection_gone_away(self):
@@ -423,49 +506,31 @@ def test_connection_gone_away(self):
time.sleep(2)
with self.assertRaises(pymysql.OperationalError) as cm:
cur.execute("SELECT 1+1")
- # error occures while reading, not writing because of socket buffer.
- #self.assertEqual(cm.exception.args[0], 2006)
+ # error occurs while reading, not writing because of socket buffer.
+ # self.assertEqual(cm.exception.args[0], 2006)
self.assertIn(cm.exception.args[0], (2006, 2013))
def test_init_command(self):
conn = self.connect(
init_command='SELECT "bar"; SELECT "baz"',
- client_flag=CLIENT.MULTI_STATEMENTS)
+ client_flag=CLIENT.MULTI_STATEMENTS,
+ )
c = conn.cursor()
c.execute('select "foobar";')
- self.assertEqual(('foobar',), c.fetchone())
+ self.assertEqual(("foobar",), c.fetchone())
conn.close()
with self.assertRaises(pymysql.err.Error):
conn.ping(reconnect=False)
def test_read_default_group(self):
conn = self.connect(
- read_default_group='client',
+ read_default_group="client",
)
self.assertTrue(conn.open)
- def test_context(self):
- with self.assertRaises(ValueError):
- c = self.connect()
- with c as cur:
- cur.execute('create table test ( a int ) ENGINE=InnoDB')
- c.begin()
- cur.execute('insert into test values ((1))')
- raise ValueError('pseudo abort')
- c.commit()
- c = self.connect()
- with c as cur:
- cur.execute('select count(*) from test')
- self.assertEqual(0, cur.fetchone()[0])
- cur.execute('insert into test values ((1))')
- with c as cur:
- cur.execute('select count(*) from test')
- self.assertEqual(1,cur.fetchone()[0])
- cur.execute('drop table test')
-
def test_set_charset(self):
c = self.connect()
- c.set_charset('utf8mb4')
+ c.set_charset("utf8mb4")
# TODO validate setting here
def test_defer_connect(self):
@@ -474,12 +539,13 @@ def test_defer_connect(self):
d = self.databases[0].copy()
try:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
- sock.connect(d['unix_socket'])
+ sock.connect(d["unix_socket"])
except KeyError:
sock.close()
sock = socket.create_connection(
- (d.get('host', 'localhost'), d.get('port', 3306)))
- for k in ['unix_socket', 'host', 'port']:
+ (d.get("host", "localhost"), d.get("port", 3306))
+ )
+ for k in ["unix_socket", "host", "port"]:
try:
del d[k]
except KeyError:
@@ -491,9 +557,250 @@ def test_defer_connect(self):
c.close()
sock.close()
+ def test_ssl_connect(self):
+ dummy_ssl_context = mock.Mock(options=0)
+ with mock.patch("pymysql.connections.Connection.connect"), mock.patch(
+ "pymysql.connections.ssl.create_default_context",
+ new=mock.Mock(return_value=dummy_ssl_context),
+ ) as create_default_context:
+ pymysql.connect(
+ ssl={
+ "ca": "ca",
+ "cert": "cert",
+ "key": "key",
+ "cipher": "cipher",
+ },
+ )
+ assert create_default_context.called
+ assert dummy_ssl_context.check_hostname
+ assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED
+ dummy_ssl_context.load_cert_chain.assert_called_with(
+ "cert",
+ keyfile="key",
+ password=None,
+ )
+ dummy_ssl_context.set_ciphers.assert_called_with("cipher")
+
+ dummy_ssl_context = mock.Mock(options=0)
+ with mock.patch("pymysql.connections.Connection.connect"), mock.patch(
+ "pymysql.connections.ssl.create_default_context",
+ new=mock.Mock(return_value=dummy_ssl_context),
+ ) as create_default_context:
+ pymysql.connect(
+ ssl={
+ "ca": "ca",
+ "cert": "cert",
+ "key": "key",
+ },
+ )
+ assert create_default_context.called
+ assert dummy_ssl_context.check_hostname
+ assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED
+ dummy_ssl_context.load_cert_chain.assert_called_with(
+ "cert",
+ keyfile="key",
+ password=None,
+ )
+            dummy_ssl_context.set_ciphers.assert_not_called()
+
+ dummy_ssl_context = mock.Mock(options=0)
+ with mock.patch("pymysql.connections.Connection.connect"), mock.patch(
+ "pymysql.connections.ssl.create_default_context",
+ new=mock.Mock(return_value=dummy_ssl_context),
+ ) as create_default_context:
+ pymysql.connect(
+ ssl={
+ "ca": "ca",
+ "cert": "cert",
+ "key": "key",
+ "password": "password",
+ },
+ )
+ assert create_default_context.called
+ assert dummy_ssl_context.check_hostname
+ assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED
+ dummy_ssl_context.load_cert_chain.assert_called_with(
+ "cert",
+ keyfile="key",
+ password="password",
+ )
+            dummy_ssl_context.set_ciphers.assert_not_called()
+
+ dummy_ssl_context = mock.Mock(options=0)
+ with mock.patch("pymysql.connections.Connection.connect"), mock.patch(
+ "pymysql.connections.ssl.create_default_context",
+ new=mock.Mock(return_value=dummy_ssl_context),
+ ) as create_default_context:
+ pymysql.connect(
+ ssl_ca="ca",
+ )
+ assert create_default_context.called
+ assert not dummy_ssl_context.check_hostname
+ assert dummy_ssl_context.verify_mode == ssl.CERT_NONE
+ dummy_ssl_context.load_cert_chain.assert_not_called()
+ dummy_ssl_context.set_ciphers.assert_not_called()
+
+ dummy_ssl_context = mock.Mock(options=0)
+ with mock.patch("pymysql.connections.Connection.connect"), mock.patch(
+ "pymysql.connections.ssl.create_default_context",
+ new=mock.Mock(return_value=dummy_ssl_context),
+ ) as create_default_context:
+ pymysql.connect(
+ ssl_ca="ca",
+ ssl_cert="cert",
+ ssl_key="key",
+ )
+ assert create_default_context.called
+ assert not dummy_ssl_context.check_hostname
+ assert dummy_ssl_context.verify_mode == ssl.CERT_NONE
+ dummy_ssl_context.load_cert_chain.assert_called_with(
+ "cert",
+ keyfile="key",
+ password=None,
+ )
+ dummy_ssl_context.set_ciphers.assert_not_called()
+
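+ # Truthy ssl_verify_cert values (True, "1", "yes", "true") force
+ # verify_mode=CERT_REQUIRED even without a CA; hostname checking
+ # stays off.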
+ for ssl_verify_cert in (True, "1", "yes", "true"):
+ dummy_ssl_context = mock.Mock(options=0)
+ with mock.patch("pymysql.connections.Connection.connect"), mock.patch(
+ "pymysql.connections.ssl.create_default_context",
+ new=mock.Mock(return_value=dummy_ssl_context),
+ ) as create_default_context:
+ pymysql.connect(
+ ssl_cert="cert",
+ ssl_key="key",
+ ssl_verify_cert=ssl_verify_cert,
+ )
+ assert create_default_context.called
+ assert not dummy_ssl_context.check_hostname
+ assert dummy_ssl_context.verify_mode == ssl.CERT_REQUIRED
+ dummy_ssl_context.load_cert_chain.assert_called_with(
+ "cert",
+ keyfile="key",
+ password=None,
+ )
+ dummy_ssl_context.set_ciphers.assert_not_called()
+
+ for ssl_verify_cert in (None, False, "0", "no", "false"):
+ dummy_ssl_context = mock.Mock(options=0)
+ with mock.patch("pymysql.connections.Connection.connect"), mock.patch(
+ "pymysql.connections.ssl.create_default_context",
+ new=mock.Mock(return_value=dummy_ssl_context),
+ ) as create_default_context:
+ pymysql.connect(
+ ssl_cert="cert",
+ ssl_key="key",
+ ssl_verify_cert=ssl_verify_cert,
+ )
+ assert create_default_context.called
+ assert not dummy_ssl_context.check_hostname
+ assert dummy_ssl_context.verify_mode == ssl.CERT_NONE
+ dummy_ssl_context.load_cert_chain.assert_called_with(
+ "cert",
+ keyfile="key",
+ password=None,
+ )
+ dummy_ssl_context.set_ciphers.assert_not_called()
+
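+ # Unrecognized ssl_verify_cert strings fall back to the default:
+ # verification is required only when a CA is supplied.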
+ for ssl_ca in ("ca", None):
+ for ssl_verify_cert in ("foo", "bar", ""):
+ dummy_ssl_context = mock.Mock(options=0)
+ with mock.patch("pymysql.connections.Connection.connect"), mock.patch(
+ "pymysql.connections.ssl.create_default_context",
+ new=mock.Mock(return_value=dummy_ssl_context),
+ ) as create_default_context:
+ pymysql.connect(
+ ssl_ca=ssl_ca,
+ ssl_cert="cert",
+ ssl_key="key",
+ ssl_verify_cert=ssl_verify_cert,
+ )
+ assert create_default_context.called
+ assert not dummy_ssl_context.check_hostname
+ assert dummy_ssl_context.verify_mode == (
+ ssl.CERT_REQUIRED if ssl_ca is not None else ssl.CERT_NONE
+ ), (ssl_ca, ssl_verify_cert)
+ dummy_ssl_context.load_cert_chain.assert_called_with(
+ "cert",
+ keyfile="key",
+ password=None,
+ )
+ dummy_ssl_context.set_ciphers.assert_not_called()
+
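+ # ssl_verify_identity turns on hostname checking independently of
+ # verify_mode, which stays CERT_NONE here because ssl_verify_cert
+ # is not given.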
+ dummy_ssl_context = mock.Mock(options=0)
+ with mock.patch("pymysql.connections.Connection.connect"), mock.patch(
+ "pymysql.connections.ssl.create_default_context",
+ new=mock.Mock(return_value=dummy_ssl_context),
+ ) as create_default_context:
+ pymysql.connect(
+ ssl_ca="ca",
+ ssl_cert="cert",
+ ssl_key="key",
+ ssl_verify_identity=True,
+ )
+ assert create_default_context.called
+ assert dummy_ssl_context.check_hostname
+ assert dummy_ssl_context.verify_mode == ssl.CERT_NONE
+ dummy_ssl_context.load_cert_chain.assert_called_with(
+ "cert",
+ keyfile="key",
+ password=None,
+ )
+ dummy_ssl_context.set_ciphers.assert_not_called()
+
+ dummy_ssl_context = mock.Mock(options=0)
+ with mock.patch("pymysql.connections.Connection.connect"), mock.patch(
+ "pymysql.connections.ssl.create_default_context",
+ new=mock.Mock(return_value=dummy_ssl_context),
+ ) as create_default_context:
+ pymysql.connect(
+ ssl_ca="ca",
+ ssl_cert="cert",
+ ssl_key="key",
+ ssl_key_password="password",
+ ssl_verify_identity=True,
+ )
+ assert create_default_context.called
+ assert dummy_ssl_context.check_hostname
+ assert dummy_ssl_context.verify_mode == ssl.CERT_NONE
+ dummy_ssl_context.load_cert_chain.assert_called_with(
+ "cert",
+ keyfile="key",
+ password="password",
+ )
+ dummy_ssl_context.set_ciphers.assert_not_called()
+
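+ # ssl_disabled=True wins over any other ssl options: no SSLContext
+ # is created at all.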
+ dummy_ssl_context = mock.Mock(options=0)
+ with mock.patch("pymysql.connections.Connection.connect"), mock.patch(
+ "pymysql.connections.ssl.create_default_context",
+ new=mock.Mock(return_value=dummy_ssl_context),
+ ) as create_default_context:
+ pymysql.connect(
+ ssl_disabled=True,
+ ssl={
+ "ca": "ca",
+ "cert": "cert",
+ "key": "key",
+ },
+ )
+ assert not create_default_context.called
+
+ dummy_ssl_context = mock.Mock(options=0)
+ with mock.patch("pymysql.connections.Connection.connect"), mock.patch(
+ "pymysql.connections.ssl.create_default_context",
+ new=mock.Mock(return_value=dummy_ssl_context),
+ ) as create_default_context:
+ pymysql.connect(
+ ssl_disabled=True,
+ ssl_ca="ca",
+ ssl_cert="cert",
+ ssl_key="key",
+ )
+ assert not create_default_context.called
+
# A custom type and function to escape it
-class Foo(object):
+class Foo:
value = "bar"
@@ -503,7 +810,7 @@ def escape_foo(x, d):
class TestEscape(base.PyMySQLTestCase):
def test_escape_string(self):
- con = self.connections[0]
+ con = self.connect()
cur = con.cursor()
self.assertEqual(con.escape("foo'bar"), "'foo\\'bar'")
@@ -516,46 +823,43 @@ def test_escape_string(self):
self.assertEqual(con.escape("foo'bar"), "'foo''bar'")
def test_escape_builtin_encoders(self):
- con = self.connections[0]
- cur = con.cursor()
+ con = self.connect()
val = datetime.datetime(2012, 3, 4, 5, 6)
self.assertEqual(con.escape(val, con.encoders), "'2012-03-04 05:06:00'")
def test_escape_custom_object(self):
- con = self.connections[0]
- cur = con.cursor()
+ con = self.connect()
mapping = {Foo: escape_foo}
self.assertEqual(con.escape(Foo(), mapping), "bar")
def test_escape_fallback_encoder(self):
- con = self.connections[0]
- cur = con.cursor()
+ con = self.connect()
class Custom(str):
pass
- mapping = {text_type: pymysql.escape_string}
- self.assertEqual(con.escape(Custom('foobar'), mapping), "'foobar'")
+ mapping = {str: pymysql.converters.escape_string}
+ self.assertEqual(con.escape(Custom("foobar"), mapping), "'foobar'")
def test_escape_no_default(self):
- con = self.connections[0]
- cur = con.cursor()
+ con = self.connect()
self.assertRaises(TypeError, con.escape, 42, {})
- def test_escape_dict_value(self):
- con = self.connections[0]
- cur = con.cursor()
+ def test_escape_dict_raise_typeerror(self):
+ """con.escape(dict) should raise TypeError"""
+ con = self.connect()
mapping = con.encoders.copy()
mapping[Foo] = escape_foo
- self.assertEqual(con.escape({'foo': Foo()}, mapping), {'foo': "bar"})
+ #self.assertEqual(con.escape({"foo": Foo()}, mapping), {"foo": "bar"})
+ with self.assertRaises(TypeError):
+ con.escape({"foo": Foo()})
def test_escape_list_item(self):
- con = self.connections[0]
- cur = con.cursor()
+ con = self.connect()
mapping = con.encoders.copy()
mapping[Foo] = escape_foo
@@ -564,7 +868,8 @@ def test_escape_list_item(self):
def test_previous_cursor_not_closed(self):
con = self.connect(
init_command='SELECT "bar"; SELECT "baz"',
- client_flag=CLIENT.MULTI_STATEMENTS)
+ client_flag=CLIENT.MULTI_STATEMENTS,
+ )
cur1 = con.cursor()
cur1.execute("SELECT 1; SELECT 2")
cur2 = con.cursor()
diff --git a/pymysql/tests/test_converters.py b/pymysql/tests/test_converters.py
index b7b5a9846..b36ee4b39 100644
--- a/pymysql/tests/test_converters.py
+++ b/pymysql/tests/test_converters.py
@@ -1,7 +1,5 @@
import datetime
from unittest import TestCase
-
-from pymysql._compat import PY2
from pymysql import converters
@@ -9,41 +7,30 @@
class TestConverter(TestCase):
-
def test_escape_string(self):
- self.assertEqual(
- converters.escape_string(u"foo\nbar"),
- u"foo\\nbar"
- )
-
- if PY2:
- def test_escape_string_bytes(self):
- self.assertEqual(
- converters.escape_string(b"foo\nbar"),
- b"foo\\nbar"
- )
+ self.assertEqual(converters.escape_string("foo\nbar"), "foo\\nbar")
def test_convert_datetime(self):
expected = datetime.datetime(2007, 2, 24, 23, 6, 20)
- dt = converters.convert_datetime('2007-02-24 23:06:20')
+ dt = converters.convert_datetime("2007-02-24 23:06:20")
self.assertEqual(dt, expected)
def test_convert_datetime_with_fsp(self):
expected = datetime.datetime(2007, 2, 24, 23, 6, 20, 511581)
- dt = converters.convert_datetime('2007-02-24 23:06:20.511581')
+ dt = converters.convert_datetime("2007-02-24 23:06:20.511581")
self.assertEqual(dt, expected)
def _test_convert_timedelta(self, with_negate=False, with_fsp=False):
- d = {'hours': 789, 'minutes': 12, 'seconds': 34}
- s = '%(hours)s:%(minutes)s:%(seconds)s' % d
+ d = {"hours": 789, "minutes": 12, "seconds": 34}
+ s = "%(hours)s:%(minutes)s:%(seconds)s" % d
if with_fsp:
- d['microseconds'] = 511581
- s += '.%(microseconds)s' % d
+ d["microseconds"] = 511581
+ s += ".%(microseconds)s" % d
expected = datetime.timedelta(**d)
if with_negate:
expected = -expected
- s = '-' + s
+ s = "-" + s
tdelta = converters.convert_timedelta(s)
self.assertEqual(tdelta, expected)
@@ -58,10 +45,10 @@ def test_convert_timedelta_with_fsp(self):
def test_convert_time(self):
expected = datetime.time(23, 6, 20)
- time_obj = converters.convert_time('23:06:20')
+ time_obj = converters.convert_time("23:06:20")
self.assertEqual(time_obj, expected)
def test_convert_time_with_fsp(self):
expected = datetime.time(23, 6, 20, 511581)
- time_obj = converters.convert_time('23:06:20.511581')
+ time_obj = converters.convert_time("23:06:20.511581")
self.assertEqual(time_obj, expected)
diff --git a/pymysql/tests/test_cursor.py b/pymysql/tests/test_cursor.py
index add047550..2e267fb6a 100644
--- a/pymysql/tests/test_cursor.py
+++ b/pymysql/tests/test_cursor.py
@@ -1,25 +1,37 @@
-import warnings
-
+from pymysql.constants import ER
from pymysql.tests import base
import pymysql.cursors
+import pytest
+
+
class CursorTest(base.PyMySQLTestCase):
def setUp(self):
- super(CursorTest, self).setUp()
+ super().setUp()
- conn = self.connections[0]
+ conn = self.connect()
self.safe_create_table(
conn,
- "test", "create table test (data varchar(10))",
+ "test",
+ "create table test (data varchar(10))",
)
cursor = conn.cursor()
cursor.execute(
- "insert into test (data) values "
- "('row1'), ('row2'), ('row3'), ('row4'), ('row5')")
+ "insert into test (data) values ('row1'), ('row2'), ('row3'), ('row4'), ('row5')"
+ )
+ conn.commit()
cursor.close()
self.test_connection = pymysql.connect(**self.databases[0])
self.addCleanup(self.test_connection.close)
+ def test_cursor_is_iterator(self):
+ """Test that the cursor is an iterator"""
+ conn = self.test_connection
+ cursor = conn.cursor()
+ cursor.execute("select * from test")
+ self.assertEqual(cursor.__iter__(), cursor)
+ self.assertEqual(cursor.__next__(), ("row1",))
+
def test_cleanup_rows_unbuffered(self):
conn = self.test_connection
cursor = conn.cursor(pymysql.cursors.SSCursor)
@@ -30,7 +42,6 @@ def test_cleanup_rows_unbuffered(self):
break
del cursor
- self.safe_gc_collect()
c2 = conn.cursor()
@@ -48,61 +59,164 @@ def test_cleanup_rows_buffered(self):
break
del cursor
- self.safe_gc_collect()
c2 = conn.cursor()
-
c2.execute("select 1")
- self.assertEqual(
- c2.fetchone(), (1,)
- )
+ self.assertEqual(c2.fetchone(), (1,))
self.assertIsNone(c2.fetchone())
def test_executemany(self):
conn = self.test_connection
cursor = conn.cursor(pymysql.cursors.Cursor)
- m = pymysql.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%s, %s)")
- self.assertIsNotNone(m, 'error parse %s')
- self.assertEqual(m.group(3), '', 'group 3 not blank, bug in RE_INSERT_VALUES?')
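+ # RE_INSERT_VALUES is what executemany() uses to locate the VALUES(...)
+ # clause so many rows can be folded into a single INSERT; group(3)
+ # captures any trailing suffix such as "ON DUPLICATE KEY UPDATE ...".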
+ m = pymysql.cursors.RE_INSERT_VALUES.match(
+ "INSERT INTO TEST (ID, NAME) VALUES (%s, %s)"
+ )
+ self.assertIsNotNone(m, "error parse %s")
+ self.assertEqual(m.group(3), "", "group 3 not blank, bug in RE_INSERT_VALUES?")
- m = pymysql.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id)s, %(name)s)")
- self.assertIsNotNone(m, 'error parse %(name)s')
- self.assertEqual(m.group(3), '', 'group 3 not blank, bug in RE_INSERT_VALUES?')
+ m = pymysql.cursors.RE_INSERT_VALUES.match(
+ "INSERT INTO TEST (ID, NAME) VALUES (%(id)s, %(name)s)"
+ )
+ self.assertIsNotNone(m, "error parse %(name)s")
+ self.assertEqual(m.group(3), "", "group 3 not blank, bug in RE_INSERT_VALUES?")
- m = pymysql.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s)")
- self.assertIsNotNone(m, 'error parse %(id_name)s')
- self.assertEqual(m.group(3), '', 'group 3 not blank, bug in RE_INSERT_VALUES?')
+ m = pymysql.cursors.RE_INSERT_VALUES.match(
+ "INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s)"
+ )
+ self.assertIsNotNone(m, "error parse %(id_name)s")
+ self.assertEqual(m.group(3), "", "group 3 not blank, bug in RE_INSERT_VALUES?")
- m = pymysql.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s) ON duplicate update")
- self.assertIsNotNone(m, 'error parse %(id_name)s')
- self.assertEqual(m.group(3), ' ON duplicate update', 'group 3 not ON duplicate update, bug in RE_INSERT_VALUES?')
+ m = pymysql.cursors.RE_INSERT_VALUES.match(
+ "INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s) ON duplicate update"
+ )
+ self.assertIsNotNone(m, "error parse %(id_name)s")
+ self.assertEqual(
+ m.group(3),
+ " ON duplicate update",
+ "group 3 not ON duplicate update, bug in RE_INSERT_VALUES?",
+ )
# https://github.com/PyMySQL/PyMySQL/pull/597
- m = pymysql.cursors.RE_INSERT_VALUES.match("INSERT INTO bloup(foo, bar)VALUES(%s, %s)")
+ m = pymysql.cursors.RE_INSERT_VALUES.match(
+ "INSERT INTO bloup(foo, bar)VALUES(%s, %s)"
+ )
assert m is not None
- # cursor._executed must bee "insert into test (data) values (0),(1),(2),(3),(4),(5),(6),(7),(8),(9)"
+ # cursor._executed must be "insert into test (data)
+ # values (0),(1),(2),(3),(4),(5),(6),(7),(8),(9)"
# list args
data = range(10)
cursor.executemany("insert into test (data) values (%s)", data)
- self.assertTrue(cursor._executed.endswith(b",(7),(8),(9)"), 'execute many with %s not in one query')
+ self.assertTrue(
+ cursor._executed.endswith(b",(7),(8),(9)"),
+ "execute many with %s not in one query",
+ )
# dict args
- data_dict = [{'data': i} for i in range(10)]
+ data_dict = [{"data": i} for i in range(10)]
cursor.executemany("insert into test (data) values (%(data)s)", data_dict)
- self.assertTrue(cursor._executed.endswith(b",(7),(8),(9)"), 'execute many with %(data)s not in one query')
+ self.assertTrue(
+ cursor._executed.endswith(b",(7),(8),(9)"),
+ "execute many with %(data)s not in one query",
+ )
# %% in column set
- cursor.execute("""\
+ cursor.execute(
+ """\
CREATE TABLE percent_test (
`A%` INTEGER,
- `B%` INTEGER)""")
+ `B%` INTEGER)"""
+ )
try:
q = "INSERT INTO percent_test (`A%%`, `B%%`) VALUES (%s, %s)"
self.assertIsNotNone(pymysql.cursors.RE_INSERT_VALUES.match(q))
cursor.executemany(q, [(3, 4), (5, 6)])
- self.assertTrue(cursor._executed.endswith(b"(3, 4),(5, 6)"), "executemany with %% not in one query")
+ self.assertTrue(
+ cursor._executed.endswith(b"(3, 4),(5, 6)"),
+ "executemany with %% not in one query",
+ )
finally:
cursor.execute("DROP TABLE IF EXISTS percent_test")
+
+ def test_execution_time_limit(self):
+ # this method is similarly implemented in test_SScursor
+
+ conn = self.test_connection
+ db_type = self.get_mysql_vendor(conn)
+
+ with conn.cursor(pymysql.cursors.Cursor) as cur:
+ # MySQL MAX_EXECUTION_TIME takes ms
+ # MariaDB max_statement_time takes seconds as int/float, introduced in 10.1
+
+ # this will sleep 0.01 seconds per row
+ if db_type == "mysql":
+ sql = (
+ "SELECT /*+ MAX_EXECUTION_TIME(2000) */ data, sleep(0.01) FROM test"
+ )
+ else:
+ sql = "SET STATEMENT max_statement_time=2 FOR SELECT data, sleep(0.01) FROM test"
+
+ cur.execute(sql)
+ # unlike SSCursor, Cursor returns a tuple of tuples here
+ self.assertEqual(
+ cur.fetchall(),
+ (
+ ("row1", 0),
+ ("row2", 0),
+ ("row3", 0),
+ ("row4", 0),
+ ("row5", 0),
+ ),
+ )
+
+ if db_type == "mysql":
+ sql = (
+ "SELECT /*+ MAX_EXECUTION_TIME(2000) */ data, sleep(0.01) FROM test"
+ )
+ else:
+ sql = "SET STATEMENT max_statement_time=2 FOR SELECT data, sleep(0.01) FROM test"
+ cur.execute(sql)
+ self.assertEqual(cur.fetchone(), ("row1", 0))
+
+ # this discards the previous unfinished query
+ cur.execute("SELECT 1")
+ self.assertEqual(cur.fetchone(), (1,))
+
+ if db_type == "mysql":
+ sql = "SELECT /*+ MAX_EXECUTION_TIME(1) */ data, sleep(1) FROM test"
+ else:
+ sql = "SET STATEMENT max_statement_time=0.001 FOR SELECT data, sleep(1) FROM test"
+ with pytest.raises(pymysql.err.OperationalError) as cm:
+ # in a buffered cursor this should reliably raise an
+ # OperationalError
+ cur.execute(sql)
+
+ if db_type == "mysql":
+ # this constant was only introduced in MySQL 5.7, not sure
+ # what was returned before, may have been ER_QUERY_INTERRUPTED
+ self.assertEqual(cm.value.args[0], ER.QUERY_TIMEOUT)
+ else:
+ self.assertEqual(cm.value.args[0], ER.STATEMENT_TIMEOUT)
+
+ # connection should still be fine at this point
+ cur.execute("SELECT 1")
+ self.assertEqual(cur.fetchone(), (1,))
+
+ def test_warnings(self):
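+ # DROP TABLE IF EXISTS on a missing table produces a server note
+ # (ER_BAD_TABLE_ERROR); it should be reflected in warning_count and
+ # retrievable via SHOW WARNINGS, and the next statement resets the count.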
+ con = self.connect()
+ cur = con.cursor()
+ cur.execute("DROP TABLE IF EXISTS `no_exists_table`")
+ self.assertEqual(cur.warning_count, 1)
+
+ cur.execute("SHOW WARNINGS")
+ w = cur.fetchone()
+ self.assertEqual(w[1], ER.BAD_TABLE_ERROR)
+ self.assertIn(
+ "no_exists_table",
+ w[2],
+ )
+
+ cur.execute("SELECT 1")
+ self.assertEqual(cur.warning_count, 0)
diff --git a/pymysql/tests/test_err.py b/pymysql/tests/test_err.py
index 3468d1b10..6eb0f987d 100644
--- a/pymysql/tests/test_err.py
+++ b/pymysql/tests/test_err.py
@@ -1,21 +1,16 @@
-import unittest2
-
+import pytest
from pymysql import err
-__all__ = ["TestRaiseException"]
-
-
-class TestRaiseException(unittest2.TestCase):
-
- def test_raise_mysql_exception(self):
- data = b"\xff\x15\x04Access denied"
- with self.assertRaises(err.OperationalError) as cm:
- err.raise_mysql_exception(data)
- self.assertEqual(cm.exception.args, (1045, 'Access denied'))
+def test_raise_mysql_exception():
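+ # An error packet starts with 0xFF, then a little-endian error code
+ # (\x15\x04 = 1045, \x10\x04 = 1040), an optional "#" + SQLSTATE, and
+ # the message text.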
+ data = b"\xff\x15\x04#28000Access denied"
+ with pytest.raises(err.OperationalError) as cm:
+ err.raise_mysql_exception(data)
+ assert cm.type == err.OperationalError
+ assert cm.value.args == (1045, "Access denied")
- def test_raise_mysql_exception_client_protocol_41(self):
- data = b"\xff\x15\x04#28000Access denied"
- with self.assertRaises(err.OperationalError) as cm:
- err.raise_mysql_exception(data)
- self.assertEqual(cm.exception.args, (1045, 'Access denied'))
+ data = b"\xff\x10\x04Too many connections"
+ with pytest.raises(err.OperationalError) as cm:
+ err.raise_mysql_exception(data)
+ assert cm.type == err.OperationalError
+ assert cm.value.args == (1040, "Too many connections")
diff --git a/pymysql/tests/test_issues.py b/pymysql/tests/test_issues.py
index cedd09258..f1fe8dd48 100644
--- a/pymysql/tests/test_issues.py
+++ b/pymysql/tests/test_issues.py
@@ -1,34 +1,29 @@
import datetime
import time
import warnings
-import sys
+
+import pytest
import pymysql
-from pymysql import cursors
-from pymysql._compat import text_type
from pymysql.tests import base
-import unittest2
-
-try:
- import imp
- reload = imp.reload
-except AttributeError:
- pass
-
__all__ = ["TestOldIssues", "TestNewIssues", "TestGitHubIssues"]
+
class TestOldIssues(base.PyMySQLTestCase):
def test_issue_3(self):
- """ undefined methods datetime_or_None, date_or_None """
- conn = self.connections[0]
+ """undefined methods datetime_or_None, date_or_None"""
+ conn = self.connect()
c = conn.cursor()
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
c.execute("drop table if exists issue3")
c.execute("create table issue3 (d date, t time, dt datetime, ts timestamp)")
try:
- c.execute("insert into issue3 (d, t, dt, ts) values (%s,%s,%s,%s)", (None, None, None, None))
+ c.execute(
+ "insert into issue3 (d, t, dt, ts) values (%s,%s,%s,%s)",
+ (None, None, None, None),
+ )
c.execute("select d from issue3")
self.assertEqual(None, c.fetchone()[0])
c.execute("select t from issue3")
@@ -36,13 +31,17 @@ def test_issue_3(self):
c.execute("select dt from issue3")
self.assertEqual(None, c.fetchone()[0])
c.execute("select ts from issue3")
- self.assertIn(type(c.fetchone()[0]), (type(None), datetime.datetime), 'expected Python type None or datetime from SQL timestamp')
+ self.assertIn(
+ type(c.fetchone()[0]),
+ (type(None), datetime.datetime),
+ "expected Python type None or datetime from SQL timestamp",
+ )
finally:
c.execute("drop table issue3")
def test_issue_4(self):
- """ can't retrieve TIMESTAMP fields """
- conn = self.connections[0]
+ """can't retrieve TIMESTAMP fields"""
+ conn = self.connect()
c = conn.cursor()
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
@@ -56,32 +55,34 @@ def test_issue_4(self):
c.execute("drop table issue4")
def test_issue_5(self):
- """ query on information_schema.tables fails """
- con = self.connections[0]
+ """query on information_schema.tables fails"""
+ con = self.connect()
cur = con.cursor()
cur.execute("select * from information_schema.tables")
def test_issue_6(self):
- """ exception: TypeError: ord() expected a character, but string of length 0 found """
+ """exception: TypeError: ord() expected a character, but string of length 0 found"""
# ToDo: this test requires access to db 'mysql'.
kwargs = self.databases[0].copy()
- kwargs['db'] = "mysql"
+ kwargs["database"] = "mysql"
conn = pymysql.connect(**kwargs)
c = conn.cursor()
c.execute("select * from user")
conn.close()
def test_issue_8(self):
- """ Primary Key and Index error when selecting data """
- conn = self.connections[0]
+ """Primary Key and Index error when selecting data"""
+ conn = self.connect()
c = conn.cursor()
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
c.execute("drop table if exists test")
- c.execute("""CREATE TABLE `test` (`station` int(10) NOT NULL DEFAULT '0', `dh`
-datetime NOT NULL DEFAULT '2015-01-01 00:00:00', `echeance` int(1) NOT NULL
+ c.execute(
+ """CREATE TABLE `test` (`station` int NOT NULL DEFAULT '0', `dh`
+datetime NOT NULL DEFAULT '2015-01-01 00:00:00', `echeance` int NOT NULL
DEFAULT '0', `me` double DEFAULT NULL, `mo` double DEFAULT NULL, PRIMARY
-KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""")
+KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;"""
+ )
try:
self.assertEqual(0, c.execute("SELECT * FROM test"))
c.execute("ALTER TABLE `test` ADD INDEX `idx_station` (`station`)")
@@ -89,16 +90,9 @@ def test_issue_8(self):
finally:
c.execute("drop table test")
- def test_issue_9(self):
- """ sets DeprecationWarning in Python 2.6 """
- try:
- reload(pymysql)
- except DeprecationWarning:
- self.fail()
-
def test_issue_13(self):
- """ can't handle large result fields """
- conn = self.connections[0]
+ """can't handle large result fields"""
+ conn = self.connect()
cur = conn.cursor()
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
@@ -106,7 +100,7 @@ def test_issue_13(self):
try:
cur.execute("create table issue13 (t text)")
# ticket says 18k
- size = 18*1024
+ size = 18 * 1024
cur.execute("insert into issue13 (t) values (%s)", ("x" * size,))
cur.execute("select t from issue13")
# use assertTrue so that obscenely huge error messages don't print
@@ -116,41 +110,47 @@ def test_issue_13(self):
cur.execute("drop table issue13")
def test_issue_15(self):
- """ query should be expanded before perform character encoding """
- conn = self.connections[0]
+ """query should be expanded before perform character encoding"""
+ conn = self.connect()
c = conn.cursor()
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
c.execute("drop table if exists issue15")
c.execute("create table issue15 (t varchar(32))")
try:
- c.execute("insert into issue15 (t) values (%s)", (u'\xe4\xf6\xfc',))
+ c.execute("insert into issue15 (t) values (%s)", ("\xe4\xf6\xfc",))
c.execute("select t from issue15")
- self.assertEqual(u'\xe4\xf6\xfc', c.fetchone()[0])
+ self.assertEqual("\xe4\xf6\xfc", c.fetchone()[0])
finally:
c.execute("drop table issue15")
def test_issue_16(self):
- """ Patch for string and tuple escaping """
- conn = self.connections[0]
+ """Patch for string and tuple escaping"""
+ conn = self.connect()
c = conn.cursor()
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
c.execute("drop table if exists issue16")
- c.execute("create table issue16 (name varchar(32) primary key, email varchar(32))")
+ c.execute(
+ "create table issue16 (name varchar(32) primary key, email varchar(32))"
+ )
try:
- c.execute("insert into issue16 (name, email) values ('pete', 'floydophone')")
+ c.execute(
+ "insert into issue16 (name, email) values ('pete', 'floydophone')"
+ )
c.execute("select email from issue16 where name=%s", ("pete",))
self.assertEqual("floydophone", c.fetchone()[0])
finally:
c.execute("drop table issue16")
- @unittest2.skip("test_issue_17() requires a custom, legacy MySQL configuration and will not be run.")
+ @pytest.mark.skip(
+ "test_issue_17() requires a custom, legacy MySQL configuration and will not be run."
+ )
def test_issue_17(self):
- """could not connect mysql use passwod"""
- conn = self.connections[0]
+ """could not connect mysql use password"""
+ conn = self.connect()
host = self.databases[0]["host"]
- db = self.databases[0]["db"]
+ db = self.databases[0]["database"]
c = conn.cursor()
# grant access to a table to a user with a password
@@ -160,7 +160,10 @@ def test_issue_17(self):
c.execute("drop table if exists issue17")
c.execute("create table issue17 (x varchar(32) primary key)")
c.execute("insert into issue17 (x) values ('hello, world!')")
- c.execute("grant all privileges on %s.issue17 to 'issue17user'@'%%' identified by '1234'" % db)
+ c.execute(
+ "grant all privileges on %s.issue17 to 'issue17user'@'%%' identified by '1234'"
+ % db
+ )
conn.commit()
conn2 = pymysql.connect(host=host, user="issue17user", passwd="1234", db=db)
@@ -170,6 +173,7 @@ def test_issue_17(self):
finally:
c.execute("drop table issue17")
+
class TestNewIssues(base.PyMySQLTestCase):
def test_issue_34(self):
try:
@@ -182,16 +186,17 @@ def test_issue_34(self):
def test_issue_33(self):
conn = pymysql.connect(charset="utf8", **self.databases[0])
- self.safe_create_table(conn, u'hei\xdfe',
- u'create table hei\xdfe (name varchar(32))')
+ self.safe_create_table(
+ conn, "hei\xdfe", "create table hei\xdfe (name varchar(32))"
+ )
c = conn.cursor()
- c.execute(u"insert into hei\xdfe (name) values ('Pi\xdfata')")
- c.execute(u"select name from hei\xdfe")
- self.assertEqual(u"Pi\xdfata", c.fetchone()[0])
+ c.execute("insert into hei\xdfe (name) values ('Pi\xdfata')")
+ c.execute("select name from hei\xdfe")
+ self.assertEqual("Pi\xdfata", c.fetchone()[0])
- @unittest2.skip("This test requires manual intervention")
+ @pytest.mark.skip("This test requires manual intervention")
def test_issue_35(self):
- conn = self.connections[0]
+ conn = self.connect()
c = conn.cursor()
print("sudo killall -9 mysqld within the next 10 seconds")
try:
@@ -237,7 +242,7 @@ def test_issue_36(self):
del self.connections[1]
def test_issue_37(self):
- conn = self.connections[0]
+ conn = self.connect()
c = conn.cursor()
self.assertEqual(1, c.execute("SELECT @foo"))
self.assertEqual((None,), c.fetchone())
@@ -245,9 +250,9 @@ def test_issue_37(self):
c.execute("set @foo = 'bar'")
def test_issue_38(self):
- conn = self.connections[0]
+ conn = self.connect()
c = conn.cursor()
- datum = "a" * 1024 * 1023 # reduced size for most default mysql installs
+ datum = "a" * 1024 * 1023 # reduced size for most default mysql installs
try:
with warnings.catch_warnings():
@@ -259,13 +264,13 @@ def test_issue_38(self):
c.execute("drop table issue38")
def disabled_test_issue_54(self):
- conn = self.connections[0]
+ conn = self.connect()
c = conn.cursor()
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
c.execute("drop table if exists issue54")
big_sql = "select * from issue54 where "
- big_sql += " and ".join("%d=%d" % (i,i) for i in range(0, 100000))
+ big_sql += " and ".join("%d=%d" % (i, i) for i in range(0, 100000))
try:
c.execute("create table issue54 (id integer primary key)")
@@ -275,17 +280,20 @@ def disabled_test_issue_54(self):
finally:
c.execute("drop table issue54")
+
class TestGitHubIssues(base.PyMySQLTestCase):
def test_issue_66(self):
- """ 'Connection' object has no attribute 'insert_id' """
- conn = self.connections[0]
+ """'Connection' object has no attribute 'insert_id'"""
+ conn = self.connect()
c = conn.cursor()
self.assertEqual(0, conn.insert_id())
try:
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
c.execute("drop table if exists issue66")
- c.execute("create table issue66 (id integer primary key auto_increment, x integer)")
+ c.execute(
+ "create table issue66 (id integer primary key auto_increment, x integer)"
+ )
c.execute("insert into issue66 (x) values (1)")
c.execute("insert into issue66 (x) values (1)")
self.assertEqual(2, conn.insert_id())
@@ -293,8 +301,8 @@ def test_issue_66(self):
c.execute("drop table issue66")
def test_issue_79(self):
- """ Duplicate field overwrites the previous one in the result of DictCursor """
- conn = self.connections[0]
+ """Duplicate field overwrites the previous one in the result of DictCursor"""
+ conn = self.connect()
c = conn.cursor(pymysql.cursors.DictCursor)
with warnings.catch_warnings():
@@ -304,32 +312,34 @@ def test_issue_79(self):
c.execute("""CREATE TABLE a (id int, value int)""")
c.execute("""CREATE TABLE b (id int, value int)""")
- a=(1,11)
- b=(1,22)
+ a = (1, 11)
+ b = (1, 22)
try:
c.execute("insert into a values (%s, %s)", a)
c.execute("insert into b values (%s, %s)", b)
c.execute("SELECT * FROM a inner join b on a.id = b.id")
r = c.fetchall()[0]
- self.assertEqual(r['id'], 1)
- self.assertEqual(r['value'], 11)
- self.assertEqual(r['b.value'], 22)
+ self.assertEqual(r["id"], 1)
+ self.assertEqual(r["value"], 11)
+ self.assertEqual(r["b.value"], 22)
finally:
c.execute("drop table a")
c.execute("drop table b")
def test_issue_95(self):
- """ Leftover trailing OK packet for "CALL my_sp" queries """
- conn = self.connections[0]
+ """Leftover trailing OK packet for "CALL my_sp" queries"""
+ conn = self.connect()
cur = conn.cursor()
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
cur.execute("DROP PROCEDURE IF EXISTS `foo`")
- cur.execute("""CREATE PROCEDURE `foo` ()
+ cur.execute(
+ """CREATE PROCEDURE `foo` ()
BEGIN
SELECT 1;
- END""")
+ END"""
+ )
try:
cur.execute("""CALL foo()""")
cur.execute("""SELECT 1""")
@@ -340,7 +350,7 @@ def test_issue_95(self):
cur.execute("DROP PROCEDURE IF EXISTS `foo`")
def test_issue_114(self):
- """ autocommit is not set after reconnecting with ping() """
+ """autocommit is not set after reconnecting with ping()"""
conn = pymysql.connect(charset="utf8", **self.databases[0])
conn.autocommit(False)
c = conn.cursor()
@@ -365,59 +375,62 @@ def test_issue_114(self):
conn.close()
def test_issue_175(self):
- """ The number of fields returned by server is read in wrong way """
- conn = self.connections[0]
+ """The number of fields returned by server is read in wrong way"""
+ conn = self.connect()
cur = conn.cursor()
for length in (200, 300):
- columns = ', '.join('c{0} integer'.format(i) for i in range(length))
- sql = 'create table test_field_count ({0})'.format(columns)
+ columns = ", ".join(f"c{i} integer" for i in range(length))
+ sql = f"create table test_field_count ({columns})"
try:
cur.execute(sql)
- cur.execute('select * from test_field_count')
+ cur.execute("select * from test_field_count")
assert len(cur.description) == length
finally:
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
- cur.execute('drop table if exists test_field_count')
+ cur.execute("drop table if exists test_field_count")
def test_issue_321(self):
- """ Test iterable as query argument. """
+ """Test iterable as query argument."""
conn = pymysql.connect(charset="utf8", **self.databases[0])
self.safe_create_table(
- conn, "issue321",
- "create table issue321 (value_1 varchar(1), value_2 varchar(1))")
+ conn,
+ "issue321",
+ "create table issue321 (value_1 varchar(1), value_2 varchar(1))",
+ )
sql_insert = "insert into issue321 (value_1, value_2) values (%s, %s)"
- sql_dict_insert = ("insert into issue321 (value_1, value_2) "
- "values (%(value_1)s, %(value_2)s)")
- sql_select = ("select * from issue321 where "
- "value_1 in %s and value_2=%s")
+ sql_dict_insert = (
+ "insert into issue321 (value_1, value_2) values (%(value_1)s, %(value_2)s)"
+ )
+ sql_select = "select * from issue321 where value_1 in %s and value_2=%s"
data = [
- [(u"a", ), u"\u0430"],
- [[u"b"], u"\u0430"],
- {"value_1": [[u"c"]], "value_2": u"\u0430"}
+ [("a",), "\u0430"],
+ [["b"], "\u0430"],
+ {"value_1": [["c"]], "value_2": "\u0430"},
]
cur = conn.cursor()
self.assertEqual(cur.execute(sql_insert, data[0]), 1)
self.assertEqual(cur.execute(sql_insert, data[1]), 1)
self.assertEqual(cur.execute(sql_dict_insert, data[2]), 1)
- self.assertEqual(
- cur.execute(sql_select, [(u"a", u"b", u"c"), u"\u0430"]), 3)
- self.assertEqual(cur.fetchone(), (u"a", u"\u0430"))
- self.assertEqual(cur.fetchone(), (u"b", u"\u0430"))
- self.assertEqual(cur.fetchone(), (u"c", u"\u0430"))
+ self.assertEqual(cur.execute(sql_select, [("a", "b", "c"), "\u0430"]), 3)
+ self.assertEqual(cur.fetchone(), ("a", "\u0430"))
+ self.assertEqual(cur.fetchone(), ("b", "\u0430"))
+ self.assertEqual(cur.fetchone(), ("c", "\u0430"))
def test_issue_364(self):
- """ Test mixed unicode/binary arguments in executemany. """
+ """Test mixed unicode/binary arguments in executemany."""
conn = pymysql.connect(charset="utf8mb4", **self.databases[0])
self.safe_create_table(
- conn, "issue364",
+ conn,
+ "issue364",
"create table issue364 (value_1 binary(3), value_2 varchar(3)) "
- "engine=InnoDB default charset=utf8mb4")
+ "engine=InnoDB default charset=utf8mb4",
+ )
sql = "insert into issue364 (value_1, value_2) values (_binary %s, %s)"
- usql = u"insert into issue364 (value_1, value_2) values (_binary %s, %s)"
- values = [pymysql.Binary(b"\x00\xff\x00"), u"\xe4\xf6\xfc"]
+ usql = "insert into issue364 (value_1, value_2) values (_binary %s, %s)"
+ values = [pymysql.Binary(b"\x00\xff\x00"), "\xe4\xf6\xfc"]
# test single insert and select
cur = conn.cursor()
@@ -438,45 +451,44 @@ def test_issue_364(self):
cur.executemany(usql, args=(values, values, values))
def test_issue_363(self):
- """ Test binary / geometry types. """
+ """Test binary / geometry types."""
conn = pymysql.connect(charset="utf8", **self.databases[0])
self.safe_create_table(
- conn, "issue363",
+ conn,
+ "issue363",
"CREATE TABLE issue363 ( "
"id INTEGER PRIMARY KEY, geom LINESTRING NOT NULL /*!80003 SRID 0 */, "
"SPATIAL KEY geom (geom)) "
- "ENGINE=MyISAM")
+ "ENGINE=MyISAM",
+ )
cur = conn.cursor()
- # From MySQL 5.7, ST_GeomFromText is added and GeomFromText is deprecated.
- if self.mysql_server_is(conn, (5, 7, 0)):
- geom_from_text = "ST_GeomFromText"
- geom_as_text = "ST_AsText"
- geom_as_bin = "ST_AsBinary"
- else:
- geom_from_text = "GeomFromText"
- geom_as_text = "AsText"
- geom_as_bin = "AsBinary"
- query = ("INSERT INTO issue363 (id, geom) VALUES"
- "(1998, %s('LINESTRING(1.1 1.1,2.2 2.2)'))" % geom_from_text)
+ query = (
+ "INSERT INTO issue363 (id, geom) VALUES"
+ "(1998, ST_GeomFromText('LINESTRING(1.1 1.1,2.2 2.2)'))"
+ )
cur.execute(query)
# select WKT
- query = "SELECT %s(geom) FROM issue363" % geom_as_text
+ query = "SELECT ST_AsText(geom) FROM issue363"
cur.execute(query)
row = cur.fetchone()
- self.assertEqual(row, ("LINESTRING(1.1 1.1,2.2 2.2)", ))
+ self.assertEqual(row, ("LINESTRING(1.1 1.1,2.2 2.2)",))
# select WKB
- query = "SELECT %s(geom) FROM issue363" % geom_as_bin
+ query = "SELECT ST_AsBinary(geom) FROM issue363"
cur.execute(query)
row = cur.fetchone()
- self.assertEqual(row,
- (b"\x01\x02\x00\x00\x00\x02\x00\x00\x00"
- b"\x9a\x99\x99\x99\x99\x99\xf1?"
- b"\x9a\x99\x99\x99\x99\x99\xf1?"
- b"\x9a\x99\x99\x99\x99\x99\x01@"
- b"\x9a\x99\x99\x99\x99\x99\x01@", ))
+ self.assertEqual(
+ row,
+ (
+ b"\x01\x02\x00\x00\x00\x02\x00\x00\x00"
+ b"\x9a\x99\x99\x99\x99\x99\xf1?"
+ b"\x9a\x99\x99\x99\x99\x99\xf1?"
+ b"\x9a\x99\x99\x99\x99\x99\x01@"
+ b"\x9a\x99\x99\x99\x99\x99\x01@",
+ ),
+ )
# select internal binary
cur.execute("SELECT geom FROM issue363")
@@ -484,28 +496,3 @@ def test_issue_363(self):
# don't assert the exact internal binary value, as it could
# vary across implementations
self.assertTrue(isinstance(row[0], bytes))
-
- def test_issue_491(self):
- """ Test warning propagation """
- conn = pymysql.connect(charset="utf8", **self.databases[0])
-
- with warnings.catch_warnings():
- # Ignore all warnings other than pymysql generated ones
- warnings.simplefilter("ignore")
- warnings.simplefilter("error", category=pymysql.Warning)
-
- # verify for both buffered and unbuffered cursor types
- for cursor_class in (cursors.Cursor, cursors.SSCursor):
- c = conn.cursor(cursor_class)
- try:
- c.execute("SELECT CAST('124b' AS SIGNED)")
- c.fetchall()
- except pymysql.Warning as e:
- # Warnings should have errorcode and string message, just like exceptions
- self.assertEqual(len(e.args), 2)
- self.assertEqual(e.args[0], 1292)
- self.assertTrue(isinstance(e.args[1], text_type))
- else:
- self.fail("Should raise Warning")
- finally:
- c.close()
diff --git a/pymysql/tests/test_load_local.py b/pymysql/tests/test_load_local.py
index 85fd94ea5..509221420 100644
--- a/pymysql/tests/test_load_local.py
+++ b/pymysql/tests/test_load_local.py
@@ -1,8 +1,8 @@
-from pymysql import cursors, OperationalError, Warning
+from pymysql import cursors, OperationalError
+from pymysql.constants import ER
from pymysql.tests import base
import os
-import warnings
__all__ = ["TestLoadLocal"]
@@ -10,15 +10,17 @@
class TestLoadLocal(base.PyMySQLTestCase):
def test_no_file(self):
"""Test load local infile when the file does not exist"""
- conn = self.connections[0]
+ conn = self.connect()
c = conn.cursor()
c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
try:
self.assertRaises(
OperationalError,
c.execute,
- ("LOAD DATA LOCAL INFILE 'no_data.txt' INTO TABLE "
- "test_load_local fields terminated by ','")
+ (
+ "LOAD DATA LOCAL INFILE 'no_data.txt' INTO TABLE "
+ "test_load_local fields terminated by ','"
+ ),
)
finally:
c.execute("DROP TABLE test_load_local")
@@ -26,16 +28,16 @@ def test_no_file(self):
def test_load_file(self):
"""Test load local infile with a valid file"""
- conn = self.connections[0]
+ conn = self.connect()
c = conn.cursor()
c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
- filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),
- 'data',
- 'load_local_data.txt')
+ filename = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)), "data", "load_local_data.txt"
+ )
try:
c.execute(
- ("LOAD DATA LOCAL INFILE '{0}' INTO TABLE " +
- "test_load_local FIELDS TERMINATED BY ','").format(filename)
+ f"LOAD DATA LOCAL INFILE '{filename}' INTO TABLE test_load_local"
+ + " FIELDS TERMINATED BY ','"
)
c.execute("SELECT COUNT(*) FROM test_load_local")
self.assertEqual(22749, c.fetchone()[0])
@@ -44,16 +46,16 @@ def test_load_file(self):
def test_unbuffered_load_file(self):
"""Test unbuffered load local infile with a valid file"""
- conn = self.connections[0]
+ conn = self.connect()
c = conn.cursor(cursors.SSCursor)
c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
- filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),
- 'data',
- 'load_local_data.txt')
+ filename = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)), "data", "load_local_data.txt"
+ )
try:
c.execute(
- ("LOAD DATA LOCAL INFILE '{0}' INTO TABLE " +
- "test_load_local FIELDS TERMINATED BY ','").format(filename)
+ f"LOAD DATA LOCAL INFILE '{filename}' INTO TABLE test_load_local"
+ + " FIELDS TERMINATED BY ','"
)
c.execute("SELECT COUNT(*) FROM test_load_local")
self.assertEqual(22749, c.fetchone()[0])
@@ -66,23 +68,31 @@ def test_unbuffered_load_file(self):
def test_load_warnings(self):
"""Test load local infile produces the appropriate warnings"""
- conn = self.connections[0]
+ conn = self.connect()
c = conn.cursor()
c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
- filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),
- 'data',
- 'load_local_warn_data.txt')
+ filename = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "data",
+ "load_local_warn_data.txt",
+ )
try:
- with warnings.catch_warnings(record=True) as w:
- warnings.simplefilter('always')
- c.execute(
- ("LOAD DATA LOCAL INFILE '{0}' INTO TABLE " +
- "test_load_local FIELDS TERMINATED BY ','").format(filename)
- )
- self.assertEqual(w[0].category, Warning)
- expected_message = "Incorrect integer value"
- if expected_message not in str(w[-1].message):
- self.fail("%r not in %r" % (expected_message, w[-1].message))
+ c.execute(
+ f"LOAD DATA LOCAL INFILE '{filename}' INTO TABLE test_load_local"
+ + " FIELDS TERMINATED BY ','"
+ )
+ self.assertEqual(1, c.warning_count)
+
+ c.execute("SHOW WARNINGS")
+ w = c.fetchone()
+
+ self.assertEqual(ER.TRUNCATED_WRONG_VALUE_FOR_FIELD, w[1])
+ self.assertIn(
+ "incorrect integer value",
+ w[2].lower(),
+ )
finally:
c.execute("DROP TABLE test_load_local")
c.close()
@@ -90,4 +100,5 @@ def test_load_warnings(self):
if __name__ == "__main__":
import unittest
+
unittest.main()
diff --git a/pymysql/tests/test_nextset.py b/pymysql/tests/test_nextset.py
index 998441072..a10f8d5b7 100644
--- a/pymysql/tests/test_nextset.py
+++ b/pymysql/tests/test_nextset.py
@@ -1,17 +1,16 @@
-import unittest2
+import pytest
import pymysql
-from pymysql import util
from pymysql.tests import base
from pymysql.constants import CLIENT
class TestNextset(base.PyMySQLTestCase):
-
def test_nextset(self):
con = self.connect(
init_command='SELECT "bar"; SELECT "baz"',
- client_flag=CLIENT.MULTI_STATEMENTS)
+ client_flag=CLIENT.MULTI_STATEMENTS,
+ )
cur = con.cursor()
cur.execute("SELECT 1; SELECT 2;")
self.assertEqual([(1,)], list(cur))
@@ -39,7 +38,7 @@ def test_nextset_error(self):
self.assertEqual([(i,)], list(cur))
with self.assertRaises(pymysql.ProgrammingError):
cur.nextset()
- self.assertEqual((), cur.fetchall())
+ self.assertEqual([], cur.fetchall())
def test_ok_and_next(self):
cur = self.connect(client_flag=CLIENT.MULTI_STATEMENTS).cursor()
@@ -50,7 +49,7 @@ def test_ok_and_next(self):
self.assertEqual([(2,)], list(cur))
self.assertFalse(bool(cur.nextset()))
- @unittest2.expectedFailure
+ @pytest.mark.xfail
def test_multi_cursor(self):
con = self.connect(client_flag=CLIENT.MULTI_STATEMENTS)
cur1 = con.cursor()
@@ -71,14 +70,14 @@ def test_multi_cursor(self):
def test_multi_statement_warnings(self):
con = self.connect(
init_command='SELECT "bar"; SELECT "baz"',
- client_flag=CLIENT.MULTI_STATEMENTS)
+ client_flag=CLIENT.MULTI_STATEMENTS,
+ )
cursor = con.cursor()
try:
- cursor.execute('DROP TABLE IF EXISTS a; '
- 'DROP TABLE IF EXISTS b;')
+ cursor.execute("DROP TABLE IF EXISTS a; DROP TABLE IF EXISTS b;")
except TypeError:
self.fail()
- #TODO: How about SSCursor and nextset?
+ # TODO: How about SSCursor and nextset?
# It's very hard to implement correctly...
diff --git a/pymysql/tests/test_optionfile.py b/pymysql/tests/test_optionfile.py
index 3ee519e2b..d13553dda 100644
--- a/pymysql/tests/test_optionfile.py
+++ b/pymysql/tests/test_optionfile.py
@@ -1,33 +1,24 @@
-from pymysql.optionfile import Parser
+from io import StringIO
from unittest import TestCase
-from pymysql._compat import PY2
-
-try:
- from cStringIO import StringIO
-except ImportError:
- from io import StringIO
+from pymysql.optionfile import Parser
-__all__ = ['TestParser']
+__all__ = ["TestParser"]
-_cfg_file = (r"""
+_cfg_file = r"""
[default]
string = foo
quoted = "bar"
single_quoted = 'foobar'
skip-slave-start
-""")
+"""
class TestParser(TestCase):
-
def test_string(self):
parser = Parser()
- if PY2:
- parser.readfp(StringIO(_cfg_file))
- else:
- parser.read_file(StringIO(_cfg_file))
+ parser.read_file(StringIO(_cfg_file))
self.assertEqual(parser.get("default", "string"), "foo")
self.assertEqual(parser.get("default", "quoted"), "bar")
- self.assertEqual(parser.get("default", "single_quoted"), "foobar")
+ self.assertEqual(parser.get("default", "single-quoted"), "foobar")
diff --git a/pymysql/tests/thirdparty/__init__.py b/pymysql/tests/thirdparty/__init__.py
index 6d59e1127..d5f053711 100644
--- a/pymysql/tests/thirdparty/__init__.py
+++ b/pymysql/tests/thirdparty/__init__.py
@@ -1,8 +1,6 @@
from .test_MySQLdb import *
if __name__ == "__main__":
- try:
- import unittest2 as unittest
- except ImportError:
- import unittest
+ import unittest
+
unittest.main()
diff --git a/pymysql/tests/thirdparty/test_MySQLdb/__init__.py b/pymysql/tests/thirdparty/test_MySQLdb/__init__.py
index e4237c69a..501bfd2db 100644
--- a/pymysql/tests/thirdparty/test_MySQLdb/__init__.py
+++ b/pymysql/tests/thirdparty/test_MySQLdb/__init__.py
@@ -1,7 +1,6 @@
-from .test_MySQLdb_capabilities import test_MySQLdb as test_capabilities
from .test_MySQLdb_nonstandard import *
-from .test_MySQLdb_dbapi20 import test_MySQLdb as test_dbapi2
if __name__ == "__main__":
import unittest
+
unittest.main()
diff --git a/pymysql/tests/thirdparty/test_MySQLdb/capabilities.py b/pymysql/tests/thirdparty/test_MySQLdb/capabilities.py
index bcf9eecbf..bb47cc5f6 100644
--- a/pymysql/tests/thirdparty/test_MySQLdb/capabilities.py
+++ b/pymysql/tests/thirdparty/test_MySQLdb/capabilities.py
@@ -4,17 +4,11 @@
Adapted from a script by M-A Lemburg.
"""
-import sys
from time import time
-try:
- import unittest2 as unittest
-except ImportError:
- import unittest
+import unittest
-PY2 = sys.version_info[0] == 2
class DatabaseTest(unittest.TestCase):
-
db_module = None
connect_args = ()
connect_kwargs = dict(use_unicode=True, charset="utf8mb4", binary_prefix=True)
@@ -26,11 +20,8 @@ def setUp(self):
db = self.db_module.connect(*self.connect_args, **self.connect_kwargs)
self.connection = db
self.cursor = db.cursor()
- self.BLOBText = ''.join([chr(i) for i in range(256)] * 100);
- if PY2:
- self.BLOBUText = unicode().join(unichr(i) for i in range(16834))
- else:
- self.BLOBUText = "".join(chr(i) for i in range(16834))
+ self.BLOBText = "".join([chr(i) for i in range(256)] * 100)
+ self.BLOBUText = "".join(chr(i) for i in range(16834))
data = bytearray(range(256)) * 16
self.BLOBBinary = self.db_module.Binary(data)
@@ -39,17 +30,22 @@ def setUp(self):
def tearDown(self):
if self.leak_test:
import gc
+
del self.cursor
orphans = gc.collect()
- self.assertFalse(orphans, "%d orphaned objects found after deleting cursor" % orphans)
+ self.assertFalse(
+ orphans, "%d orphaned objects found after deleting cursor" % orphans
+ )
del self.connection
orphans = gc.collect()
- self.assertFalse(orphans, "%d orphaned objects found after deleting connection" % orphans)
+ self.assertFalse(
+ orphans, "%d orphaned objects found after deleting connection" % orphans
+ )
def table_exists(self, name):
try:
- self.cursor.execute('select * from %s where 1=0' % name)
+ self.cursor.execute("select * from %s where 1=0" % name)
except Exception:
return False
else:
@@ -61,41 +57,41 @@ def quote_identifier(self, ident):
def new_table_name(self):
i = id(self.cursor)
while True:
- name = self.quote_identifier('tb%08x' % i)
+ name = self.quote_identifier("tb%08x" % i)
if not self.table_exists(name):
return name
i = i + 1
def create_table(self, columndefs):
+ """
+ Create a table using a list of column definitions given in columndefs.
- """ Create a table using a list of column definitions given in
- columndefs.
-
- generator must be a function taking arguments (row_number,
- col_number) returning a suitable data object for insertion
- into the table.
-
+ generator must be a function taking arguments (row_number,
+ col_number) returning a suitable data object for insertion
+ into the table.
"""
self.table = self.new_table_name()
- self.cursor.execute('CREATE TABLE %s (%s) %s' %
- (self.table,
- ',\n'.join(columndefs),
- self.create_table_extra))
+ self.cursor.execute(
+ "CREATE TABLE %s (%s) %s"
+ % (self.table, ",\n".join(columndefs), self.create_table_extra)
+ )
def check_data_integrity(self, columndefs, generator):
# insert
self.create_table(columndefs)
- insert_statement = ('INSERT INTO %s VALUES (%s)' %
- (self.table,
- ','.join(['%s'] * len(columndefs))))
- data = [ [ generator(i,j) for j in range(len(columndefs)) ]
- for i in range(self.rows) ]
+ insert_statement = "INSERT INTO %s VALUES (%s)" % (
+ self.table,
+ ",".join(["%s"] * len(columndefs)),
+ )
+ data = [
+ [generator(i, j) for j in range(len(columndefs))] for i in range(self.rows)
+ ]
if self.debug:
print(data)
self.cursor.executemany(insert_statement, data)
self.connection.commit()
# verify
- self.cursor.execute('select * from %s' % self.table)
+ self.cursor.execute("select * from %s" % self.table)
l = self.cursor.fetchall()
if self.debug:
print(l)
@@ -103,62 +99,74 @@ def check_data_integrity(self, columndefs, generator):
try:
for i in range(self.rows):
for j in range(len(columndefs)):
- self.assertEqual(l[i][j], generator(i,j))
+ self.assertEqual(l[i][j], generator(i, j))
finally:
if not self.debug:
- self.cursor.execute('drop table %s' % (self.table))
+ self.cursor.execute("drop table %s" % (self.table))
def test_transactions(self):
- columndefs = ( 'col1 INT', 'col2 VARCHAR(255)')
+ columndefs = ("col1 INT", "col2 VARCHAR(255)")
+
def generator(row, col):
- if col == 0: return row
- else: return ('%i' % (row%10))*255
+ if col == 0:
+ return row
+ else:
+ return ("%i" % (row % 10)) * 255
+
self.create_table(columndefs)
- insert_statement = ('INSERT INTO %s VALUES (%s)' %
- (self.table,
- ','.join(['%s'] * len(columndefs))))
- data = [ [ generator(i,j) for j in range(len(columndefs)) ]
- for i in range(self.rows) ]
+ insert_statement = "INSERT INTO %s VALUES (%s)" % (
+ self.table,
+ ",".join(["%s"] * len(columndefs)),
+ )
+ data = [
+ [generator(i, j) for j in range(len(columndefs))] for i in range(self.rows)
+ ]
self.cursor.executemany(insert_statement, data)
# verify
self.connection.commit()
- self.cursor.execute('select * from %s' % self.table)
+ self.cursor.execute("select * from %s" % self.table)
l = self.cursor.fetchall()
self.assertEqual(len(l), self.rows)
for i in range(self.rows):
for j in range(len(columndefs)):
- self.assertEqual(l[i][j], generator(i,j))
- delete_statement = 'delete from %s where col1=%%s' % self.table
+ self.assertEqual(l[i][j], generator(i, j))
+ delete_statement = "delete from %s where col1=%%s" % self.table
self.cursor.execute(delete_statement, (0,))
- self.cursor.execute('select col1 from %s where col1=%s' % \
- (self.table, 0))
+ self.cursor.execute("select col1 from %s where col1=%s" % (self.table, 0))
l = self.cursor.fetchall()
self.assertFalse(l, "DELETE didn't work")
self.connection.rollback()
- self.cursor.execute('select col1 from %s where col1=%s' % \
- (self.table, 0))
+ self.cursor.execute("select col1 from %s where col1=%s" % (self.table, 0))
l = self.cursor.fetchall()
self.assertTrue(len(l) == 1, "ROLLBACK didn't work")
- self.cursor.execute('drop table %s' % (self.table))
+ self.cursor.execute("drop table %s" % (self.table))
def test_truncation(self):
- columndefs = ( 'col1 INT', 'col2 VARCHAR(255)')
+ columndefs = ("col1 INT", "col2 VARCHAR(255)")
+
def generator(row, col):
- if col == 0: return row
- else: return ('%i' % (row%10))*((255-self.rows//2)+row)
+ if col == 0:
+ return row
+ else:
+ return ("%i" % (row % 10)) * ((255 - self.rows // 2) + row)
+
self.create_table(columndefs)
- insert_statement = ('INSERT INTO %s VALUES (%s)' %
- (self.table,
- ','.join(['%s'] * len(columndefs))))
+ insert_statement = "INSERT INTO %s VALUES (%s)" % (
+ self.table,
+ ",".join(["%s"] * len(columndefs)),
+ )
try:
- self.cursor.execute(insert_statement, (0, '0'*256))
+ self.cursor.execute(insert_statement, (0, "0" * 256))
except Warning:
- if self.debug: print(self.cursor.messages)
+ if self.debug:
+ print(self.cursor.messages)
except self.connection.DataError:
pass
else:
- self.fail("Over-long column did not generate warnings/exception with single insert")
+ self.fail(
+ "Over-long column did not generate warnings/exception with single insert"
+ )
self.connection.rollback()
@@ -166,132 +174,136 @@ def generator(row, col):
for i in range(self.rows):
data = []
for j in range(len(columndefs)):
- data.append(generator(i,j))
- self.cursor.execute(insert_statement,tuple(data))
+ data.append(generator(i, j))
+ self.cursor.execute(insert_statement, tuple(data))
except Warning:
- if self.debug: print(self.cursor.messages)
+ if self.debug:
+ print(self.cursor.messages)
except self.connection.DataError:
pass
else:
- self.fail("Over-long columns did not generate warnings/exception with execute()")
+ self.fail(
+ "Over-long columns did not generate warnings/exception with execute()"
+ )
self.connection.rollback()
try:
- data = [ [ generator(i,j) for j in range(len(columndefs)) ]
- for i in range(self.rows) ]
+ data = [
+ [generator(i, j) for j in range(len(columndefs))]
+ for i in range(self.rows)
+ ]
self.cursor.executemany(insert_statement, data)
except Warning:
- if self.debug: print(self.cursor.messages)
+ if self.debug:
+ print(self.cursor.messages)
except self.connection.DataError:
pass
else:
- self.fail("Over-long columns did not generate warnings/exception with executemany()")
+ self.fail(
+ "Over-long columns did not generate warnings/exception with executemany()"
+ )
self.connection.rollback()
- self.cursor.execute('drop table %s' % (self.table))
+ self.cursor.execute("drop table %s" % (self.table))
def test_CHAR(self):
# Character data
- def generator(row,col):
- return ('%i' % ((row+col) % 10)) * 255
- self.check_data_integrity(
- ('col1 char(255)','col2 char(255)'),
- generator)
+ def generator(row, col):
+ return ("%i" % ((row + col) % 10)) * 255
+
+ self.check_data_integrity(("col1 char(255)", "col2 char(255)"), generator)
def test_INT(self):
# Number data
- def generator(row,col):
- return row*row
- self.check_data_integrity(
- ('col1 INT',),
- generator)
+ def generator(row, col):
+ return row * row
+
+ self.check_data_integrity(("col1 INT",), generator)
def test_DECIMAL(self):
# DECIMAL
- def generator(row,col):
+ def generator(row, col):
from decimal import Decimal
+
return Decimal("%d.%02d" % (row, col))
- self.check_data_integrity(
- ('col1 DECIMAL(5,2)',),
- generator)
+
+ self.check_data_integrity(("col1 DECIMAL(5,2)",), generator)
def test_DATE(self):
ticks = time()
- def generator(row,col):
- return self.db_module.DateFromTicks(ticks+row*86400-col*1313)
- self.check_data_integrity(
- ('col1 DATE',),
- generator)
+
+ def generator(row, col):
+ return self.db_module.DateFromTicks(ticks + row * 86400 - col * 1313)
+
+ self.check_data_integrity(("col1 DATE",), generator)
def test_TIME(self):
ticks = time()
- def generator(row,col):
- return self.db_module.TimeFromTicks(ticks+row*86400-col*1313)
- self.check_data_integrity(
- ('col1 TIME',),
- generator)
+
+ def generator(row, col):
+ return self.db_module.TimeFromTicks(ticks + row * 86400 - col * 1313)
+
+ self.check_data_integrity(("col1 TIME",), generator)
def test_DATETIME(self):
ticks = time()
- def generator(row,col):
- return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313)
- self.check_data_integrity(
- ('col1 DATETIME',),
- generator)
+
+ def generator(row, col):
+ return self.db_module.TimestampFromTicks(ticks + row * 86400 - col * 1313)
+
+ self.check_data_integrity(("col1 DATETIME",), generator)
def test_TIMESTAMP(self):
ticks = time()
- def generator(row,col):
- return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313)
- self.check_data_integrity(
- ('col1 TIMESTAMP',),
- generator)
+
+ def generator(row, col):
+ return self.db_module.TimestampFromTicks(ticks + row * 86400 - col * 1313)
+
+ self.check_data_integrity(("col1 TIMESTAMP",), generator)
def test_fractional_TIMESTAMP(self):
ticks = time()
- def generator(row,col):
- return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313+row*0.7*col/3.0)
- self.check_data_integrity(
- ('col1 TIMESTAMP',),
- generator)
+
+ def generator(row, col):
+ return self.db_module.TimestampFromTicks(
+ ticks + row * 86400 - col * 1313 + row * 0.7 * col / 3.0
+ )
+
+ self.check_data_integrity(("col1 TIMESTAMP",), generator)
def test_LONG(self):
- def generator(row,col):
+ def generator(row, col):
if col == 0:
return row
else:
- return self.BLOBUText # 'BLOB Text ' * 1024
- self.check_data_integrity(
- ('col1 INT', 'col2 LONG'),
- generator)
+ return self.BLOBUText # 'BLOB Text ' * 1024
+
+ self.check_data_integrity(("col1 INT", "col2 LONG"), generator)
def test_TEXT(self):
- def generator(row,col):
+ def generator(row, col):
if col == 0:
return row
else:
- return self.BLOBUText[:5192] # 'BLOB Text ' * 1024
- self.check_data_integrity(
- ('col1 INT', 'col2 TEXT'),
- generator)
+ return self.BLOBUText[:5192] # 'BLOB Text ' * 1024
+
+ self.check_data_integrity(("col1 INT", "col2 TEXT"), generator)
def test_LONG_BYTE(self):
- def generator(row,col):
+ def generator(row, col):
if col == 0:
return row
else:
- return self.BLOBBinary # 'BLOB\000Binary ' * 1024
- self.check_data_integrity(
- ('col1 INT','col2 LONG BYTE'),
- generator)
+ return self.BLOBBinary # 'BLOB\000Binary ' * 1024
+
+ self.check_data_integrity(("col1 INT", "col2 LONG BYTE"), generator)
def test_BLOB(self):
- def generator(row,col):
+ def generator(row, col):
if col == 0:
return row
else:
- return self.BLOBBinary # 'BLOB\000Binary ' * 1024
- self.check_data_integrity(
- ('col1 INT','col2 BLOB'),
- generator)
+ return self.BLOBBinary # 'BLOB\000Binary ' * 1024
+
+ self.check_data_integrity(("col1 INT", "col2 BLOB"), generator)
diff --git a/pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py b/pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py
index 3cbf2263d..fff14b86f 100644
--- a/pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py
+++ b/pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py
@@ -1,4 +1,4 @@
-''' Python DB API 2.0 driver compliance unit test suite.
+""" Python DB API 2.0 driver compliance unit test suite.
This software is Public Domain and may be used without restrictions.
@@ -8,18 +8,14 @@
this is turning out to be a thoroughly unwholesome unit test."
-- Ian Bicking
-'''
+"""
-__rcs_id__ = '$Id$'
-__version__ = '$Revision$'[11:-2]
-__author__ = 'Stuart Bishop '
-
-try:
- import unittest2 as unittest
-except ImportError:
- import unittest
+__rcs_id__ = "$Id$"
+__version__ = "$Revision$"[11:-2]
+__author__ = "Stuart Bishop "
import time
+import unittest
# $Log$
# Revision 1.1.2.1 2006/02/25 03:44:32 adustman
@@ -55,9 +51,9 @@
# - Now a subclass of TestCase, to avoid requiring the driver stub
# to use multiple inheritance
# - Reversed the polarity of buggy test in test_description
-# - Test exception heirarchy correctly
+# - Test exception hierarchy correctly
# - self.populate is now self._populate(), so if a driver stub
-# overrides self.ddl1 this change propogates
+# overrides self.ddl1 this change propagates
# - VARCHAR columns now have a width, which will hopefully make the
# DDL even more portible (this will be reversed if it causes more problems)
# - cursor.rowcount being checked after various execute and fetchXXX methods
@@ -67,65 +63,66 @@
# - Fix bugs in test_setoutputsize_basic and test_setinputsizes
#
+
class DatabaseAPI20Test(unittest.TestCase):
- ''' Test a database self.driver for DB API 2.0 compatibility.
- This implementation tests Gadfly, but the TestCase
- is structured so that other self.drivers can subclass this
- test case to ensure compiliance with the DB-API. It is
- expected that this TestCase may be expanded in the future
- if ambiguities or edge conditions are discovered.
+ """Test a database self.driver for DB API 2.0 compatibility.
+ This implementation tests Gadfly, but the TestCase
+ is structured so that other self.drivers can subclass this
+    test case to ensure compliance with the DB-API. It is
+ expected that this TestCase may be expanded in the future
+ if ambiguities or edge conditions are discovered.
- The 'Optional Extensions' are not yet being tested.
+ The 'Optional Extensions' are not yet being tested.
- self.drivers should subclass this test, overriding setUp, tearDown,
- self.driver, connect_args and connect_kw_args. Class specification
- should be as follows:
+ self.drivers should subclass this test, overriding setUp, tearDown,
+ self.driver, connect_args and connect_kw_args. Class specification
+ should be as follows:
- import dbapi20
- class mytest(dbapi20.DatabaseAPI20Test):
- [...]
+ import dbapi20
+ class mytest(dbapi20.DatabaseAPI20Test):
+ [...]
- Don't 'import DatabaseAPI20Test from dbapi20', or you will
- confuse the unit tester - just 'import dbapi20'.
- '''
+ Don't 'import DatabaseAPI20Test from dbapi20', or you will
+ confuse the unit tester - just 'import dbapi20'.
+ """
# The self.driver module. This should be the module where the 'connect'
# method is to be found
driver = None
- connect_args = () # List of arguments to pass to connect
- connect_kw_args = {} # Keyword arguments for connect
- table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables
+ connect_args = () # List of arguments to pass to connect
+ connect_kw_args = {} # Keyword arguments for connect
+ table_prefix = "dbapi20test_" # If you need to specify a prefix for tables
- ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix
- ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix
- xddl1 = 'drop table %sbooze' % table_prefix
- xddl2 = 'drop table %sbarflys' % table_prefix
+ ddl1 = "create table %sbooze (name varchar(20))" % table_prefix
+ ddl2 = "create table %sbarflys (name varchar(20))" % table_prefix
+ xddl1 = "drop table %sbooze" % table_prefix
+ xddl2 = "drop table %sbarflys" % table_prefix
- lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase
+ lowerfunc = "lower" # Name of stored procedure to convert string->lowercase
# Some drivers may need to override these helpers, for example adding
# a 'commit' after the execute.
- def executeDDL1(self,cursor):
+ def executeDDL1(self, cursor):
cursor.execute(self.ddl1)
- def executeDDL2(self,cursor):
+ def executeDDL2(self, cursor):
cursor.execute(self.ddl2)
def setUp(self):
- ''' self.drivers should override this method to perform required setup
- if any is necessary, such as creating the database.
- '''
+ """self.drivers should override this method to perform required setup
+ if any is necessary, such as creating the database.
+ """
pass
def tearDown(self):
- ''' self.drivers should override this method to perform required cleanup
- if any is necessary, such as deleting the test database.
- The default drops the tables that may be created.
- '''
+ """self.drivers should override this method to perform required cleanup
+ if any is necessary, such as deleting the test database.
+ The default drops the tables that may be created.
+ """
con = self._connect()
try:
cur = con.cursor()
- for ddl in (self.xddl1,self.xddl2):
+ for ddl in (self.xddl1, self.xddl2):
try:
cur.execute(ddl)
con.commit()
@@ -138,9 +135,7 @@ def tearDown(self):
def _connect(self):
try:
- return self.driver.connect(
- *self.connect_args,**self.connect_kw_args
- )
+ return self.driver.connect(*self.connect_args, **self.connect_kw_args)
except AttributeError:
self.fail("No connect method found in self.driver module")
@@ -153,7 +148,7 @@ def test_apilevel(self):
# Must exist
apilevel = self.driver.apilevel
# Must equal 2.0
- self.assertEqual(apilevel,'2.0')
+ self.assertEqual(apilevel, "2.0")
except AttributeError:
self.fail("Driver doesn't define apilevel")
@@ -162,7 +157,7 @@ def test_threadsafety(self):
# Must exist
threadsafety = self.driver.threadsafety
# Must be a valid value
- self.assertTrue(threadsafety in (0,1,2,3))
+ self.assertTrue(threadsafety in (0, 1, 2, 3))
except AttributeError:
self.fail("Driver doesn't define threadsafety")
@@ -171,38 +166,24 @@ def test_paramstyle(self):
# Must exist
paramstyle = self.driver.paramstyle
# Must be a valid value
- self.assertTrue(paramstyle in (
- 'qmark','numeric','named','format','pyformat'
- ))
+ self.assertTrue(
+ paramstyle in ("qmark", "numeric", "named", "format", "pyformat")
+ )
except AttributeError:
self.fail("Driver doesn't define paramstyle")
def test_Exceptions(self):
# Make sure required exceptions exist, and are in the
- # defined heirarchy.
- self.assertTrue(issubclass(self.driver.Warning,Exception))
- self.assertTrue(issubclass(self.driver.Error,Exception))
- self.assertTrue(
- issubclass(self.driver.InterfaceError,self.driver.Error)
- )
- self.assertTrue(
- issubclass(self.driver.DatabaseError,self.driver.Error)
- )
- self.assertTrue(
- issubclass(self.driver.OperationalError,self.driver.Error)
- )
- self.assertTrue(
- issubclass(self.driver.IntegrityError,self.driver.Error)
- )
- self.assertTrue(
- issubclass(self.driver.InternalError,self.driver.Error)
- )
- self.assertTrue(
- issubclass(self.driver.ProgrammingError,self.driver.Error)
- )
- self.assertTrue(
- issubclass(self.driver.NotSupportedError,self.driver.Error)
- )
+ # defined hierarchy.
+ self.assertTrue(issubclass(self.driver.Warning, Exception))
+ self.assertTrue(issubclass(self.driver.Error, Exception))
+ self.assertTrue(issubclass(self.driver.InterfaceError, self.driver.Error))
+ self.assertTrue(issubclass(self.driver.DatabaseError, self.driver.Error))
+ self.assertTrue(issubclass(self.driver.OperationalError, self.driver.Error))
+ self.assertTrue(issubclass(self.driver.IntegrityError, self.driver.Error))
+ self.assertTrue(issubclass(self.driver.InternalError, self.driver.Error))
+ self.assertTrue(issubclass(self.driver.ProgrammingError, self.driver.Error))
+ self.assertTrue(issubclass(self.driver.NotSupportedError, self.driver.Error))
def test_ExceptionsAsConnectionAttributes(self):
# OPTIONAL EXTENSION
@@ -223,7 +204,6 @@ def test_ExceptionsAsConnectionAttributes(self):
self.assertTrue(con.ProgrammingError is drv.ProgrammingError)
self.assertTrue(con.NotSupportedError is drv.NotSupportedError)
-
def test_commit(self):
con = self._connect()
try:
@@ -236,7 +216,7 @@ def test_rollback(self):
con = self._connect()
# If rollback is defined, it should either work or throw
# the documented exception
- if hasattr(con,'rollback'):
+ if hasattr(con, "rollback"):
try:
con.rollback()
except self.driver.NotSupportedError:
@@ -245,7 +225,7 @@ def test_rollback(self):
def test_cursor(self):
con = self._connect()
try:
- cur = con.cursor()
+ con.cursor()
finally:
con.close()
@@ -257,14 +237,14 @@ def test_cursor_isolation(self):
cur1 = con.cursor()
cur2 = con.cursor()
self.executeDDL1(cur1)
- cur1.execute("insert into %sbooze values ('Victoria Bitter')" % (
- self.table_prefix
- ))
+ cur1.execute(
+ "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix)
+ )
cur2.execute("select name from %sbooze" % self.table_prefix)
booze = cur2.fetchall()
- self.assertEqual(len(booze),1)
- self.assertEqual(len(booze[0]),1)
- self.assertEqual(booze[0][0],'Victoria Bitter')
+ self.assertEqual(len(booze), 1)
+ self.assertEqual(len(booze[0]), 1)
+ self.assertEqual(booze[0][0], "Victoria Bitter")
finally:
con.close()
@@ -273,31 +253,41 @@ def test_description(self):
try:
cur = con.cursor()
self.executeDDL1(cur)
- self.assertEqual(cur.description,None,
- 'cursor.description should be none after executing a '
- 'statement that can return no rows (such as DDL)'
- )
- cur.execute('select name from %sbooze' % self.table_prefix)
- self.assertEqual(len(cur.description),1,
- 'cursor.description describes too many columns'
- )
- self.assertEqual(len(cur.description[0]),7,
- 'cursor.description[x] tuples must have 7 elements'
- )
- self.assertEqual(cur.description[0][0].lower(),'name',
- 'cursor.description[x][0] must return column name'
- )
- self.assertEqual(cur.description[0][1],self.driver.STRING,
- 'cursor.description[x][1] must return column type. Got %r'
- % cur.description[0][1]
- )
+ self.assertEqual(
+ cur.description,
+ None,
+ "cursor.description should be none after executing a "
+ "statement that can return no rows (such as DDL)",
+ )
+ cur.execute("select name from %sbooze" % self.table_prefix)
+ self.assertEqual(
+ len(cur.description), 1, "cursor.description describes too many columns"
+ )
+ self.assertEqual(
+ len(cur.description[0]),
+ 7,
+ "cursor.description[x] tuples must have 7 elements",
+ )
+ self.assertEqual(
+ cur.description[0][0].lower(),
+ "name",
+ "cursor.description[x][0] must return column name",
+ )
+ self.assertEqual(
+ cur.description[0][1],
+ self.driver.STRING,
+ "cursor.description[x][1] must return column type. Got %r"
+ % cur.description[0][1],
+ )
# Make sure self.description gets reset
self.executeDDL2(cur)
- self.assertEqual(cur.description,None,
- 'cursor.description not being set to None when executing '
- 'no-result statements (eg. DDL)'
- )
+ self.assertEqual(
+ cur.description,
+ None,
+ "cursor.description not being set to None when executing "
+ "no-result statements (eg. DDL)",
+ )
finally:
con.close()
@@ -306,47 +296,49 @@ def test_rowcount(self):
try:
cur = con.cursor()
self.executeDDL1(cur)
- self.assertEqual(cur.rowcount,-1,
- 'cursor.rowcount should be -1 after executing no-result '
- 'statements'
- )
- cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
- self.table_prefix
- ))
- self.assertTrue(cur.rowcount in (-1,1),
- 'cursor.rowcount should == number or rows inserted, or '
- 'set to -1 after executing an insert statement'
- )
+ self.assertEqual(
+ cur.rowcount,
+ -1,
+ "cursor.rowcount should be -1 after executing no-result statements",
+ )
+ cur.execute(
+ "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix)
+ )
+ self.assertTrue(
+ cur.rowcount in (-1, 1),
+ "cursor.rowcount should == number or rows inserted, or "
+ "set to -1 after executing an insert statement",
+ )
cur.execute("select name from %sbooze" % self.table_prefix)
- self.assertTrue(cur.rowcount in (-1,1),
- 'cursor.rowcount should == number of rows returned, or '
- 'set to -1 after executing a select statement'
- )
+ self.assertTrue(
+ cur.rowcount in (-1, 1),
+ "cursor.rowcount should == number of rows returned, or "
+ "set to -1 after executing a select statement",
+ )
self.executeDDL2(cur)
- self.assertEqual(cur.rowcount,-1,
- 'cursor.rowcount not being reset to -1 after executing '
- 'no-result statements'
- )
+ self.assertEqual(
+ cur.rowcount,
+ -1,
+ "cursor.rowcount not being reset to -1 after executing "
+ "no-result statements",
+ )
finally:
con.close()
- lower_func = 'lower'
+ lower_func = "lower"
+
def test_callproc(self):
con = self._connect()
try:
cur = con.cursor()
- if self.lower_func and hasattr(cur,'callproc'):
- r = cur.callproc(self.lower_func,('FOO',))
- self.assertEqual(len(r),1)
- self.assertEqual(r[0],'FOO')
+ if self.lower_func and hasattr(cur, "callproc"):
+ r = cur.callproc(self.lower_func, ("FOO",))
+ self.assertEqual(len(r), 1)
+ self.assertEqual(r[0], "FOO")
r = cur.fetchall()
- self.assertEqual(len(r),1,'callproc produced no result set')
- self.assertEqual(len(r[0]),1,
- 'callproc produced invalid result set'
- )
- self.assertEqual(r[0][0],'foo',
- 'callproc produced invalid results'
- )
+ self.assertEqual(len(r), 1, "callproc produced no result set")
+ self.assertEqual(len(r[0]), 1, "callproc produced invalid result set")
+ self.assertEqual(r[0][0], "foo", "callproc produced invalid results")
finally:
con.close()
@@ -359,14 +351,14 @@ def test_close(self):
# cursor.execute should raise an Error if called after connection
# closed
- self.assertRaises(self.driver.Error,self.executeDDL1,cur)
+ self.assertRaises(self.driver.Error, self.executeDDL1, cur)
# connection.commit should raise an Error if called after connection'
# closed.'
- self.assertRaises(self.driver.Error,con.commit)
+ self.assertRaises(self.driver.Error, con.commit)
# connection.close should raise an Error if called more than once
- self.assertRaises(self.driver.Error,con.close)
+ self.assertRaises(self.driver.Error, con.close)
def test_execute(self):
con = self._connect()
@@ -376,105 +368,99 @@ def test_execute(self):
finally:
con.close()
- def _paraminsert(self,cur):
+ def _paraminsert(self, cur):
self.executeDDL1(cur)
- cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
- self.table_prefix
- ))
- self.assertTrue(cur.rowcount in (-1,1))
+ cur.execute(
+ "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix)
+ )
+ self.assertTrue(cur.rowcount in (-1, 1))
- if self.driver.paramstyle == 'qmark':
+ if self.driver.paramstyle == "qmark":
cur.execute(
- 'insert into %sbooze values (?)' % self.table_prefix,
- ("Cooper's",)
- )
- elif self.driver.paramstyle == 'numeric':
+ "insert into %sbooze values (?)" % self.table_prefix, ("Cooper's",)
+ )
+ elif self.driver.paramstyle == "numeric":
cur.execute(
- 'insert into %sbooze values (:1)' % self.table_prefix,
- ("Cooper's",)
- )
- elif self.driver.paramstyle == 'named':
+ "insert into %sbooze values (:1)" % self.table_prefix, ("Cooper's",)
+ )
+ elif self.driver.paramstyle == "named":
cur.execute(
- 'insert into %sbooze values (:beer)' % self.table_prefix,
- {'beer':"Cooper's"}
- )
- elif self.driver.paramstyle == 'format':
+ "insert into %sbooze values (:beer)" % self.table_prefix,
+ {"beer": "Cooper's"},
+ )
+ elif self.driver.paramstyle == "format":
cur.execute(
- 'insert into %sbooze values (%%s)' % self.table_prefix,
- ("Cooper's",)
- )
- elif self.driver.paramstyle == 'pyformat':
+ "insert into %sbooze values (%%s)" % self.table_prefix, ("Cooper's",)
+ )
+ elif self.driver.paramstyle == "pyformat":
cur.execute(
- 'insert into %sbooze values (%%(beer)s)' % self.table_prefix,
- {'beer':"Cooper's"}
- )
+ "insert into %sbooze values (%%(beer)s)" % self.table_prefix,
+ {"beer": "Cooper's"},
+ )
else:
- self.fail('Invalid paramstyle')
- self.assertTrue(cur.rowcount in (-1,1))
+ self.fail("Invalid paramstyle")
+ self.assertTrue(cur.rowcount in (-1, 1))
- cur.execute('select name from %sbooze' % self.table_prefix)
+ cur.execute("select name from %sbooze" % self.table_prefix)
res = cur.fetchall()
- self.assertEqual(len(res),2,'cursor.fetchall returned too few rows')
- beers = [res[0][0],res[1][0]]
+ self.assertEqual(len(res), 2, "cursor.fetchall returned too few rows")
+ beers = [res[0][0], res[1][0]]
beers.sort()
- self.assertEqual(beers[0],"Cooper's",
- 'cursor.fetchall retrieved incorrect data, or data inserted '
- 'incorrectly'
- )
- self.assertEqual(beers[1],"Victoria Bitter",
- 'cursor.fetchall retrieved incorrect data, or data inserted '
- 'incorrectly'
- )
+ self.assertEqual(
+ beers[0],
+ "Cooper's",
+ "cursor.fetchall retrieved incorrect data, or data inserted incorrectly",
+ )
+ self.assertEqual(
+ beers[1],
+ "Victoria Bitter",
+ "cursor.fetchall retrieved incorrect data, or data inserted incorrectly",
+ )
def test_executemany(self):
con = self._connect()
try:
cur = con.cursor()
self.executeDDL1(cur)
- largs = [ ("Cooper's",) , ("Boag's",) ]
- margs = [ {'beer': "Cooper's"}, {'beer': "Boag's"} ]
- if self.driver.paramstyle == 'qmark':
+ largs = [("Cooper's",), ("Boag's",)]
+ margs = [{"beer": "Cooper's"}, {"beer": "Boag's"}]
+ if self.driver.paramstyle == "qmark":
cur.executemany(
- 'insert into %sbooze values (?)' % self.table_prefix,
- largs
- )
- elif self.driver.paramstyle == 'numeric':
+ "insert into %sbooze values (?)" % self.table_prefix, largs
+ )
+ elif self.driver.paramstyle == "numeric":
cur.executemany(
- 'insert into %sbooze values (:1)' % self.table_prefix,
- largs
- )
- elif self.driver.paramstyle == 'named':
+ "insert into %sbooze values (:1)" % self.table_prefix, largs
+ )
+ elif self.driver.paramstyle == "named":
cur.executemany(
- 'insert into %sbooze values (:beer)' % self.table_prefix,
- margs
- )
- elif self.driver.paramstyle == 'format':
+ "insert into %sbooze values (:beer)" % self.table_prefix, margs
+ )
+ elif self.driver.paramstyle == "format":
cur.executemany(
- 'insert into %sbooze values (%%s)' % self.table_prefix,
- largs
- )
- elif self.driver.paramstyle == 'pyformat':
+ "insert into %sbooze values (%%s)" % self.table_prefix, largs
+ )
+ elif self.driver.paramstyle == "pyformat":
cur.executemany(
- 'insert into %sbooze values (%%(beer)s)' % (
- self.table_prefix
- ),
- margs
- )
- else:
- self.fail('Unknown paramstyle')
- self.assertTrue(cur.rowcount in (-1,2),
- 'insert using cursor.executemany set cursor.rowcount to '
- 'incorrect value %r' % cur.rowcount
+ "insert into %sbooze values (%%(beer)s)" % (self.table_prefix),
+ margs,
)
- cur.execute('select name from %sbooze' % self.table_prefix)
+ else:
+ self.fail("Unknown paramstyle")
+ self.assertTrue(
+ cur.rowcount in (-1, 2),
+ "insert using cursor.executemany set cursor.rowcount to "
+ "incorrect value %r" % cur.rowcount,
+ )
+ cur.execute("select name from %sbooze" % self.table_prefix)
res = cur.fetchall()
- self.assertEqual(len(res),2,
- 'cursor.fetchall retrieved incorrect number of rows'
- )
- beers = [res[0][0],res[1][0]]
+ self.assertEqual(
+ len(res), 2, "cursor.fetchall retrieved incorrect number of rows"
+ )
+ beers = [res[0][0], res[1][0]]
beers.sort()
- self.assertEqual(beers[0],"Boag's",'incorrect data retrieved')
- self.assertEqual(beers[1],"Cooper's",'incorrect data retrieved')
+ self.assertEqual(beers[0], "Boag's", "incorrect data retrieved")
+ self.assertEqual(beers[1], "Cooper's", "incorrect data retrieved")
finally:
con.close()
@@ -485,59 +471,62 @@ def test_fetchone(self):
# cursor.fetchone should raise an Error if called before
# executing a select-type query
- self.assertRaises(self.driver.Error,cur.fetchone)
+ self.assertRaises(self.driver.Error, cur.fetchone)
# cursor.fetchone should raise an Error if called after
- # executing a query that cannnot return rows
+ # executing a query that cannot return rows
self.executeDDL1(cur)
- self.assertRaises(self.driver.Error,cur.fetchone)
+ self.assertRaises(self.driver.Error, cur.fetchone)
- cur.execute('select name from %sbooze' % self.table_prefix)
- self.assertEqual(cur.fetchone(),None,
- 'cursor.fetchone should return None if a query retrieves '
- 'no rows'
- )
- self.assertTrue(cur.rowcount in (-1,0))
+ cur.execute("select name from %sbooze" % self.table_prefix)
+ self.assertEqual(
+ cur.fetchone(),
+ None,
+ "cursor.fetchone should return None if a query retrieves no rows",
+ )
+ self.assertTrue(cur.rowcount in (-1, 0))
# cursor.fetchone should raise an Error if called after
- # executing a query that cannnot return rows
- cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
- self.table_prefix
- ))
- self.assertRaises(self.driver.Error,cur.fetchone)
+ # executing a query that cannot return rows
+ cur.execute(
+ "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix)
+ )
+ self.assertRaises(self.driver.Error, cur.fetchone)
- cur.execute('select name from %sbooze' % self.table_prefix)
+ cur.execute("select name from %sbooze" % self.table_prefix)
r = cur.fetchone()
- self.assertEqual(len(r),1,
- 'cursor.fetchone should have retrieved a single row'
- )
- self.assertEqual(r[0],'Victoria Bitter',
- 'cursor.fetchone retrieved incorrect data'
- )
- self.assertEqual(cur.fetchone(),None,
- 'cursor.fetchone should return None if no more rows available'
- )
- self.assertTrue(cur.rowcount in (-1,1))
+ self.assertEqual(
+ len(r), 1, "cursor.fetchone should have retrieved a single row"
+ )
+ self.assertEqual(
+ r[0], "Victoria Bitter", "cursor.fetchone retrieved incorrect data"
+ )
+ self.assertEqual(
+ cur.fetchone(),
+ None,
+ "cursor.fetchone should return None if no more rows available",
+ )
+ self.assertTrue(cur.rowcount in (-1, 1))
finally:
con.close()
samples = [
- 'Carlton Cold',
- 'Carlton Draft',
- 'Mountain Goat',
- 'Redback',
- 'Victoria Bitter',
- 'XXXX'
- ]
+ "Carlton Cold",
+ "Carlton Draft",
+ "Mountain Goat",
+ "Redback",
+ "Victoria Bitter",
+ "XXXX",
+ ]
def _populate(self):
- ''' Return a list of sql commands to setup the DB for the fetch
- tests.
- '''
+ """Return a list of sql commands to setup the DB for the fetch
+ tests.
+ """
populate = [
- "insert into %sbooze values ('%s')" % (self.table_prefix,s)
- for s in self.samples
- ]
+ "insert into %sbooze values ('%s')" % (self.table_prefix, s)
+ for s in self.samples
+ ]
return populate
def test_fetchmany(self):
@@ -546,78 +535,88 @@ def test_fetchmany(self):
cur = con.cursor()
# cursor.fetchmany should raise an Error if called without
- #issuing a query
- self.assertRaises(self.driver.Error,cur.fetchmany,4)
+ # issuing a query
+ self.assertRaises(self.driver.Error, cur.fetchmany, 4)
self.executeDDL1(cur)
for sql in self._populate():
cur.execute(sql)
- cur.execute('select name from %sbooze' % self.table_prefix)
+ cur.execute("select name from %sbooze" % self.table_prefix)
r = cur.fetchmany()
- self.assertEqual(len(r),1,
- 'cursor.fetchmany retrieved incorrect number of rows, '
- 'default of arraysize is one.'
- )
- cur.arraysize=10
- r = cur.fetchmany(3) # Should get 3 rows
- self.assertEqual(len(r),3,
- 'cursor.fetchmany retrieved incorrect number of rows'
- )
- r = cur.fetchmany(4) # Should get 2 more
- self.assertEqual(len(r),2,
- 'cursor.fetchmany retrieved incorrect number of rows'
- )
- r = cur.fetchmany(4) # Should be an empty sequence
- self.assertEqual(len(r),0,
- 'cursor.fetchmany should return an empty sequence after '
- 'results are exhausted'
+ self.assertEqual(
+ len(r),
+ 1,
+ "cursor.fetchmany retrieved incorrect number of rows, "
+ "default of arraysize is one.",
+ )
+ cur.arraysize = 10
+ r = cur.fetchmany(3) # Should get 3 rows
+ self.assertEqual(
+ len(r), 3, "cursor.fetchmany retrieved incorrect number of rows"
+ )
+ r = cur.fetchmany(4) # Should get 2 more
+ self.assertEqual(
+ len(r), 2, "cursor.fetchmany retrieved incorrect number of rows"
+ )
+ r = cur.fetchmany(4) # Should be an empty sequence
+ self.assertEqual(
+ len(r),
+ 0,
+ "cursor.fetchmany should return an empty sequence after "
+ "results are exhausted",
)
- self.assertTrue(cur.rowcount in (-1,6))
+ self.assertTrue(cur.rowcount in (-1, 6))
# Same as above, using cursor.arraysize
- cur.arraysize=4
- cur.execute('select name from %sbooze' % self.table_prefix)
- r = cur.fetchmany() # Should get 4 rows
- self.assertEqual(len(r),4,
- 'cursor.arraysize not being honoured by fetchmany'
- )
- r = cur.fetchmany() # Should get 2 more
- self.assertEqual(len(r),2)
- r = cur.fetchmany() # Should be an empty sequence
- self.assertEqual(len(r),0)
- self.assertTrue(cur.rowcount in (-1,6))
-
- cur.arraysize=6
- cur.execute('select name from %sbooze' % self.table_prefix)
- rows = cur.fetchmany() # Should get all rows
- self.assertTrue(cur.rowcount in (-1,6))
- self.assertEqual(len(rows),6)
- self.assertEqual(len(rows),6)
+ cur.arraysize = 4
+ cur.execute("select name from %sbooze" % self.table_prefix)
+ r = cur.fetchmany() # Should get 4 rows
+ self.assertEqual(
+ len(r), 4, "cursor.arraysize not being honoured by fetchmany"
+ )
+ r = cur.fetchmany() # Should get 2 more
+ self.assertEqual(len(r), 2)
+ r = cur.fetchmany() # Should be an empty sequence
+ self.assertEqual(len(r), 0)
+ self.assertTrue(cur.rowcount in (-1, 6))
+
+ cur.arraysize = 6
+ cur.execute("select name from %sbooze" % self.table_prefix)
+ rows = cur.fetchmany() # Should get all rows
+ self.assertTrue(cur.rowcount in (-1, 6))
+ self.assertEqual(len(rows), 6)
+ self.assertEqual(len(rows), 6)
rows = [r[0] for r in rows]
rows.sort()
# Make sure we get the right data back out
- for i in range(0,6):
- self.assertEqual(rows[i],self.samples[i],
- 'incorrect data retrieved by cursor.fetchmany'
- )
-
- rows = cur.fetchmany() # Should return an empty list
- self.assertEqual(len(rows),0,
- 'cursor.fetchmany should return an empty sequence if '
- 'called after the whole result set has been fetched'
+ for i in range(0, 6):
+ self.assertEqual(
+ rows[i],
+ self.samples[i],
+ "incorrect data retrieved by cursor.fetchmany",
)
- self.assertTrue(cur.rowcount in (-1,6))
+
+ rows = cur.fetchmany() # Should return an empty list
+ self.assertEqual(
+ len(rows),
+ 0,
+ "cursor.fetchmany should return an empty sequence if "
+ "called after the whole result set has been fetched",
+ )
+ self.assertTrue(cur.rowcount in (-1, 6))
self.executeDDL2(cur)
- cur.execute('select name from %sbarflys' % self.table_prefix)
- r = cur.fetchmany() # Should get empty sequence
- self.assertEqual(len(r),0,
- 'cursor.fetchmany should return an empty sequence if '
- 'query retrieved no rows'
- )
- self.assertTrue(cur.rowcount in (-1,0))
+ cur.execute("select name from %sbarflys" % self.table_prefix)
+ r = cur.fetchmany() # Should get empty sequence
+ self.assertEqual(
+ len(r),
+ 0,
+ "cursor.fetchmany should return an empty sequence if "
+ "query retrieved no rows",
+ )
+ self.assertTrue(cur.rowcount in (-1, 0))
finally:
con.close()
@@ -637,36 +636,41 @@ def test_fetchall(self):
# cursor.fetchall should raise an Error if called
# after executing a a statement that cannot return rows
- self.assertRaises(self.driver.Error,cur.fetchall)
+ self.assertRaises(self.driver.Error, cur.fetchall)
- cur.execute('select name from %sbooze' % self.table_prefix)
+ cur.execute("select name from %sbooze" % self.table_prefix)
rows = cur.fetchall()
- self.assertTrue(cur.rowcount in (-1,len(self.samples)))
- self.assertEqual(len(rows),len(self.samples),
- 'cursor.fetchall did not retrieve all rows'
- )
+ self.assertTrue(cur.rowcount in (-1, len(self.samples)))
+ self.assertEqual(
+ len(rows),
+ len(self.samples),
+ "cursor.fetchall did not retrieve all rows",
+ )
rows = [r[0] for r in rows]
rows.sort()
- for i in range(0,len(self.samples)):
- self.assertEqual(rows[i],self.samples[i],
- 'cursor.fetchall retrieved incorrect rows'
+ for i in range(0, len(self.samples)):
+ self.assertEqual(
+ rows[i], self.samples[i], "cursor.fetchall retrieved incorrect rows"
)
rows = cur.fetchall()
self.assertEqual(
- len(rows),0,
- 'cursor.fetchall should return an empty list if called '
- 'after the whole result set has been fetched'
- )
- self.assertTrue(cur.rowcount in (-1,len(self.samples)))
+ len(rows),
+ 0,
+ "cursor.fetchall should return an empty list if called "
+ "after the whole result set has been fetched",
+ )
+ self.assertTrue(cur.rowcount in (-1, len(self.samples)))
self.executeDDL2(cur)
- cur.execute('select name from %sbarflys' % self.table_prefix)
+ cur.execute("select name from %sbarflys" % self.table_prefix)
rows = cur.fetchall()
- self.assertTrue(cur.rowcount in (-1,0))
- self.assertEqual(len(rows),0,
- 'cursor.fetchall should return an empty list if '
- 'a select query returns no rows'
- )
+ self.assertTrue(cur.rowcount in (-1, 0))
+ self.assertEqual(
+ len(rows),
+ 0,
+ "cursor.fetchall should return an empty list if "
+ "a select query returns no rows",
+ )
finally:
con.close()
@@ -679,74 +683,74 @@ def test_mixedfetch(self):
for sql in self._populate():
cur.execute(sql)
- cur.execute('select name from %sbooze' % self.table_prefix)
- rows1 = cur.fetchone()
+ cur.execute("select name from %sbooze" % self.table_prefix)
+ rows1 = cur.fetchone()
rows23 = cur.fetchmany(2)
- rows4 = cur.fetchone()
+ rows4 = cur.fetchone()
rows56 = cur.fetchall()
- self.assertTrue(cur.rowcount in (-1,6))
- self.assertEqual(len(rows23),2,
- 'fetchmany returned incorrect number of rows'
- )
- self.assertEqual(len(rows56),2,
- 'fetchall returned incorrect number of rows'
- )
+ self.assertTrue(cur.rowcount in (-1, 6))
+ self.assertEqual(
+ len(rows23), 2, "fetchmany returned incorrect number of rows"
+ )
+ self.assertEqual(
+ len(rows56), 2, "fetchall returned incorrect number of rows"
+ )
rows = [rows1[0]]
- rows.extend([rows23[0][0],rows23[1][0]])
+ rows.extend([rows23[0][0], rows23[1][0]])
rows.append(rows4[0])
- rows.extend([rows56[0][0],rows56[1][0]])
+ rows.extend([rows56[0][0], rows56[1][0]])
rows.sort()
- for i in range(0,len(self.samples)):
- self.assertEqual(rows[i],self.samples[i],
- 'incorrect data retrieved or inserted'
- )
+ for i in range(0, len(self.samples)):
+ self.assertEqual(
+ rows[i], self.samples[i], "incorrect data retrieved or inserted"
+ )
finally:
con.close()
- def help_nextset_setUp(self,cur):
- ''' Should create a procedure called deleteme
- that returns two result sets, first the
- number of rows in booze then "name from booze"
- '''
- raise NotImplementedError('Helper not implemented')
- #sql="""
+ def help_nextset_setUp(self, cur):
+ """Should create a procedure called deleteme
+ that returns two result sets, first the
+ number of rows in booze then "name from booze"
+ """
+ raise NotImplementedError("Helper not implemented")
+ # sql="""
# create procedure deleteme as
# begin
# select count(*) from booze
# select name from booze
# end
- #"""
- #cur.execute(sql)
+ # """
+ # cur.execute(sql)
- def help_nextset_tearDown(self,cur):
- 'If cleaning up is needed after nextSetTest'
- raise NotImplementedError('Helper not implemented')
- #cur.execute("drop procedure deleteme")
+ def help_nextset_tearDown(self, cur):
+ "If cleaning up is needed after nextSetTest"
+ raise NotImplementedError("Helper not implemented")
+ # cur.execute("drop procedure deleteme")
def test_nextset(self):
con = self._connect()
try:
cur = con.cursor()
- if not hasattr(cur,'nextset'):
+ if not hasattr(cur, "nextset"):
return
try:
self.executeDDL1(cur)
- sql=self._populate()
+ sql = self._populate()
for sql in self._populate():
cur.execute(sql)
self.help_nextset_setUp(cur)
- cur.callproc('deleteme')
- numberofrows=cur.fetchone()
- assert numberofrows[0]== len(self.samples)
+ cur.callproc("deleteme")
+ numberofrows = cur.fetchone()
+ assert numberofrows[0] == len(self.samples)
assert cur.nextset()
- names=cur.fetchall()
+ names = cur.fetchall()
assert len(names) == len(self.samples)
- s=cur.nextset()
- assert s == None,'No more return sets, should return None'
+ s = cur.nextset()
+ assert s == None, "No more return sets, should return None"
finally:
self.help_nextset_tearDown(cur)
@@ -754,16 +758,16 @@ def test_nextset(self):
con.close()
def test_nextset(self):
- raise NotImplementedError('Drivers need to override this test')
+ raise NotImplementedError("Drivers need to override this test")
def test_arraysize(self):
# Not much here - rest of the tests for this are in test_fetchmany
con = self._connect()
try:
cur = con.cursor()
- self.assertTrue(hasattr(cur,'arraysize'),
- 'cursor.arraysize must be defined'
- )
+ self.assertTrue(
+ hasattr(cur, "arraysize"), "cursor.arraysize must be defined"
+ )
finally:
con.close()
@@ -771,8 +775,8 @@ def test_setinputsizes(self):
con = self._connect()
try:
cur = con.cursor()
- cur.setinputsizes( (25,) )
- self._paraminsert(cur) # Make sure cursor still works
+ cur.setinputsizes((25,))
+ self._paraminsert(cur) # Make sure cursor still works
finally:
con.close()
@@ -782,74 +786,68 @@ def test_setoutputsize_basic(self):
try:
cur = con.cursor()
cur.setoutputsize(1000)
- cur.setoutputsize(2000,0)
- self._paraminsert(cur) # Make sure the cursor still works
+ cur.setoutputsize(2000, 0)
+ self._paraminsert(cur) # Make sure the cursor still works
finally:
con.close()
def test_setoutputsize(self):
- # Real test for setoutputsize is driver dependant
- raise NotImplementedError('Driver need to override this test')
+ # Real test for setoutputsize is driver dependent
+ raise NotImplementedError("Driver need to override this test")
def test_None(self):
con = self._connect()
try:
cur = con.cursor()
self.executeDDL1(cur)
- cur.execute('insert into %sbooze values (NULL)' % self.table_prefix)
- cur.execute('select name from %sbooze' % self.table_prefix)
+ cur.execute("insert into %sbooze values (NULL)" % self.table_prefix)
+ cur.execute("select name from %sbooze" % self.table_prefix)
r = cur.fetchall()
- self.assertEqual(len(r),1)
- self.assertEqual(len(r[0]),1)
- self.assertEqual(r[0][0],None,'NULL value not returned as None')
+ self.assertEqual(len(r), 1)
+ self.assertEqual(len(r[0]), 1)
+ self.assertEqual(r[0][0], None, "NULL value not returned as None")
finally:
con.close()
def test_Date(self):
- d1 = self.driver.Date(2002,12,25)
- d2 = self.driver.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0)))
+ self.driver.Date(2002, 12, 25)
+ self.driver.DateFromTicks(time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0)))
# Can we assume this? API doesn't specify, but it seems implied
# self.assertEqual(str(d1),str(d2))
def test_Time(self):
- t1 = self.driver.Time(13,45,30)
- t2 = self.driver.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0)))
+ self.driver.Time(13, 45, 30)
+ self.driver.TimeFromTicks(time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0)))
# Can we assume this? API doesn't specify, but it seems implied
# self.assertEqual(str(t1),str(t2))
def test_Timestamp(self):
- t1 = self.driver.Timestamp(2002,12,25,13,45,30)
- t2 = self.driver.TimestampFromTicks(
- time.mktime((2002,12,25,13,45,30,0,0,0))
- )
+ self.driver.Timestamp(2002, 12, 25, 13, 45, 30)
+ self.driver.TimestampFromTicks(time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0)))
# Can we assume this? API doesn't specify, but it seems implied
# self.assertEqual(str(t1),str(t2))
def test_Binary(self):
- b = self.driver.Binary(b'Something')
- b = self.driver.Binary(b'')
+ self.driver.Binary(b"Something")
+ self.driver.Binary(b"")
def test_STRING(self):
- self.assertTrue(hasattr(self.driver,'STRING'),
- 'module.STRING must be defined'
- )
+ self.assertTrue(hasattr(self.driver, "STRING"), "module.STRING must be defined")
def test_BINARY(self):
- self.assertTrue(hasattr(self.driver,'BINARY'),
- 'module.BINARY must be defined.'
- )
+ self.assertTrue(
+ hasattr(self.driver, "BINARY"), "module.BINARY must be defined."
+ )
def test_NUMBER(self):
- self.assertTrue(hasattr(self.driver,'NUMBER'),
- 'module.NUMBER must be defined.'
- )
+ self.assertTrue(
+ hasattr(self.driver, "NUMBER"), "module.NUMBER must be defined."
+ )
def test_DATETIME(self):
- self.assertTrue(hasattr(self.driver,'DATETIME'),
- 'module.DATETIME must be defined.'
- )
+ self.assertTrue(
+ hasattr(self.driver, "DATETIME"), "module.DATETIME must be defined."
+ )
def test_ROWID(self):
- self.assertTrue(hasattr(self.driver,'ROWID'),
- 'module.ROWID must be defined.'
- )
+ self.assertTrue(hasattr(self.driver, "ROWID"), "module.ROWID must be defined.")
diff --git a/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py b/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py
index 0fc5e8316..6a2894a5a 100644
--- a/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py
+++ b/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py
@@ -1,22 +1,24 @@
from . import capabilities
-try:
- import unittest2 as unittest
-except ImportError:
- import unittest
import pymysql
from pymysql.tests import base
import warnings
-warnings.filterwarnings('error')
+warnings.filterwarnings("error")
-class test_MySQLdb(capabilities.DatabaseTest):
+class test_MySQLdb(capabilities.DatabaseTest):
db_module = pymysql
connect_args = ()
connect_kwargs = base.PyMySQLTestCase.databases[0].copy()
- connect_kwargs.update(dict(read_default_file='~/.my.cnf',
- use_unicode=True, binary_prefix=True,
- charset='utf8mb4', sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL"))
+ connect_kwargs.update(
+ dict(
+ read_default_file="~/.my.cnf",
+ use_unicode=True,
+ binary_prefix=True,
+ charset="utf8mb4",
+ sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL",
+ )
+ )
leak_test = False
@@ -25,64 +27,70 @@ def quote_identifier(self, ident):
def test_TIME(self):
from datetime import timedelta
- def generator(row,col):
- return timedelta(0, row*8000)
- self.check_data_integrity(
- ('col1 TIME',),
- generator)
+
+ def generator(row, col):
+ return timedelta(0, row * 8000)
+
+ self.check_data_integrity(("col1 TIME",), generator)
def test_TINYINT(self):
# Number data
- def generator(row,col):
- v = (row*row) % 256
+ def generator(row, col):
+ v = (row * row) % 256
if v > 127:
- v = v-256
+ v = v - 256
return v
- self.check_data_integrity(
- ('col1 TINYINT',),
- generator)
+
+ self.check_data_integrity(("col1 TINYINT",), generator)
def test_stored_procedures(self):
db = self.connection
c = self.cursor
try:
- self.create_table(('pos INT', 'tree CHAR(20)'))
- c.executemany("INSERT INTO %s (pos,tree) VALUES (%%s,%%s)" % self.table,
- list(enumerate('ash birch cedar larch pine'.split())))
+ self.create_table(("pos INT", "tree CHAR(20)"))
+ c.executemany(
+ "INSERT INTO %s (pos,tree) VALUES (%%s,%%s)" % self.table,
+ list(enumerate("ash birch cedar larch pine".split())),
+ )
db.commit()
- c.execute("""
+ c.execute(
+ """
CREATE PROCEDURE test_sp(IN t VARCHAR(255))
BEGIN
SELECT pos FROM %s WHERE tree = t;
END
- """ % self.table)
+ """
+ % self.table
+ )
db.commit()
- c.callproc('test_sp', ('larch',))
+ c.callproc("test_sp", ("larch",))
rows = c.fetchall()
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0][0], 3)
c.nextset()
finally:
c.execute("DROP PROCEDURE IF EXISTS test_sp")
- c.execute('drop table %s' % (self.table))
+ c.execute("drop table %s" % (self.table))
def test_small_CHAR(self):
# Character data
- def generator(row,col):
- i = ((row+1)*(col+1)+62)%256
- if i == 62: return ''
- if i == 63: return None
+ def generator(row, col):
+ i = ((row + 1) * (col + 1) + 62) % 256
+ if i == 62:
+ return ""
+ if i == 63:
+ return None
return chr(i)
- self.check_data_integrity(
- ('col1 char(1)','col2 char(1)'),
- generator)
+
+ self.check_data_integrity(("col1 char(1)", "col2 char(1)"), generator)
def test_bug_2671682(self):
from pymysql.constants import ER
+
try:
- self.cursor.execute("describe some_non_existent_table");
+ self.cursor.execute("describe some_non_existent_table")
except self.connection.ProgrammingError as msg:
self.assertEqual(msg.args[0], ER.NO_SUCH_TABLE)
@@ -93,7 +101,7 @@ def test_literal_int(self):
self.assertTrue("2" == self.connection.literal(2))
def test_literal_float(self):
- self.assertTrue("3.1415" == self.connection.literal(3.1415))
+ self.assertEqual("3.1415e0", self.connection.literal(3.1415))
def test_literal_string(self):
self.assertTrue("'foo'" == self.connection.literal("foo"))
diff --git a/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py b/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py
index a26691626..5c34d40d1 100644
--- a/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py
+++ b/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py
@@ -2,23 +2,24 @@
import pymysql
from pymysql.tests import base
-try:
- import unittest2 as unittest
-except ImportError:
- import unittest
-
class test_MySQLdb(dbapi20.DatabaseAPI20Test):
driver = pymysql
connect_args = ()
connect_kw_args = base.PyMySQLTestCase.databases[0].copy()
- connect_kw_args.update(dict(read_default_file='~/.my.cnf',
- charset='utf8',
- sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL"))
+ connect_kw_args.update(
+ dict(
+ read_default_file="~/.my.cnf",
+ charset="utf8",
+ sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL",
+ )
+ )
+
+ def test_setoutputsize(self):
+ pass
- def test_setoutputsize(self): pass
- def test_setoutputsize_basic(self): pass
- def test_nextset(self): pass
+ def test_setoutputsize_basic(self):
+ pass
"""The tests on fetchone and fetchall and rowcount bogusly
test for an exception if the statement cannot return a
@@ -40,36 +41,41 @@ def test_fetchall(self):
# cursor.fetchall should raise an Error if called
# after executing a a statement that cannot return rows
-## self.assertRaises(self.driver.Error,cur.fetchall)
+ ## self.assertRaises(self.driver.Error,cur.fetchall)
- cur.execute('select name from %sbooze' % self.table_prefix)
+ cur.execute("select name from %sbooze" % self.table_prefix)
rows = cur.fetchall()
- self.assertTrue(cur.rowcount in (-1,len(self.samples)))
- self.assertEqual(len(rows),len(self.samples),
- 'cursor.fetchall did not retrieve all rows'
- )
+ self.assertTrue(cur.rowcount in (-1, len(self.samples)))
+ self.assertEqual(
+ len(rows),
+ len(self.samples),
+ "cursor.fetchall did not retrieve all rows",
+ )
rows = [r[0] for r in rows]
rows.sort()
- for i in range(0,len(self.samples)):
- self.assertEqual(rows[i],self.samples[i],
- 'cursor.fetchall retrieved incorrect rows'
+ for i in range(0, len(self.samples)):
+ self.assertEqual(
+ rows[i], self.samples[i], "cursor.fetchall retrieved incorrect rows"
)
rows = cur.fetchall()
self.assertEqual(
- len(rows),0,
- 'cursor.fetchall should return an empty list if called '
- 'after the whole result set has been fetched'
- )
- self.assertTrue(cur.rowcount in (-1,len(self.samples)))
+ len(rows),
+ 0,
+ "cursor.fetchall should return an empty list if called "
+ "after the whole result set has been fetched",
+ )
+ self.assertTrue(cur.rowcount in (-1, len(self.samples)))
self.executeDDL2(cur)
- cur.execute('select name from %sbarflys' % self.table_prefix)
+ cur.execute("select name from %sbarflys" % self.table_prefix)
rows = cur.fetchall()
- self.assertTrue(cur.rowcount in (-1,0))
- self.assertEqual(len(rows),0,
- 'cursor.fetchall should return an empty list if '
- 'a select query returns no rows'
- )
+ self.assertTrue(cur.rowcount in (-1, 0))
+ self.assertEqual(
+ len(rows),
+ 0,
+ "cursor.fetchall should return an empty list if "
+ "a select query returns no rows",
+ )
finally:
con.close()
@@ -81,39 +87,40 @@ def test_fetchone(self):
# cursor.fetchone should raise an Error if called before
# executing a select-type query
- self.assertRaises(self.driver.Error,cur.fetchone)
+ self.assertRaises(self.driver.Error, cur.fetchone)
# cursor.fetchone should raise an Error if called after
- # executing a query that cannnot return rows
+ # executing a query that cannot return rows
self.executeDDL1(cur)
-## self.assertRaises(self.driver.Error,cur.fetchone)
+ ## self.assertRaises(self.driver.Error,cur.fetchone)
- cur.execute('select name from %sbooze' % self.table_prefix)
- self.assertEqual(cur.fetchone(),None,
- 'cursor.fetchone should return None if a query retrieves '
- 'no rows'
- )
- self.assertTrue(cur.rowcount in (-1,0))
+ cur.execute("select name from %sbooze" % self.table_prefix)
+ self.assertEqual(
+ cur.fetchone(),
+ None,
+ "cursor.fetchone should return None if a query retrieves no rows",
+ )
+ self.assertTrue(cur.rowcount in (-1, 0))
# cursor.fetchone should raise an Error if called after
- # executing a query that cannnot return rows
- cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
- self.table_prefix
- ))
-## self.assertRaises(self.driver.Error,cur.fetchone)
+ # executing a query that cannot return rows
+ cur.execute(
+ "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix)
+ )
+ ## self.assertRaises(self.driver.Error,cur.fetchone)
- cur.execute('select name from %sbooze' % self.table_prefix)
+ cur.execute("select name from %sbooze" % self.table_prefix)
r = cur.fetchone()
- self.assertEqual(len(r),1,
- 'cursor.fetchone should have retrieved a single row'
- )
- self.assertEqual(r[0],'Victoria Bitter',
- 'cursor.fetchone retrieved incorrect data'
- )
-## self.assertEqual(cur.fetchone(),None,
-## 'cursor.fetchone should return None if no more rows available'
-## )
- self.assertTrue(cur.rowcount in (-1,1))
+ self.assertEqual(
+ len(r), 1, "cursor.fetchone should have retrieved a single row"
+ )
+ self.assertEqual(
+ r[0], "Victoria Bitter", "cursor.fetchone retrieved incorrect data"
+ )
+ ## self.assertEqual(cur.fetchone(),None,
+ ## 'cursor.fetchone should return None if no more rows available'
+ ## )
+ self.assertTrue(cur.rowcount in (-1, 1))
finally:
con.close()
@@ -123,81 +130,84 @@ def test_rowcount(self):
try:
cur = con.cursor()
self.executeDDL1(cur)
-## self.assertEqual(cur.rowcount,-1,
-## 'cursor.rowcount should be -1 after executing no-result '
-## 'statements'
-## )
- cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
- self.table_prefix
- ))
-## self.assertTrue(cur.rowcount in (-1,1),
-## 'cursor.rowcount should == number or rows inserted, or '
-## 'set to -1 after executing an insert statement'
-## )
+ ## self.assertEqual(cur.rowcount,-1,
+ ## 'cursor.rowcount should be -1 after executing no-result '
+ ## 'statements'
+ ## )
+ cur.execute(
+ "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix)
+ )
+ ## self.assertTrue(cur.rowcount in (-1,1),
+ ## 'cursor.rowcount should == number or rows inserted, or '
+ ## 'set to -1 after executing an insert statement'
+ ## )
cur.execute("select name from %sbooze" % self.table_prefix)
- self.assertTrue(cur.rowcount in (-1,1),
- 'cursor.rowcount should == number of rows returned, or '
- 'set to -1 after executing a select statement'
- )
+ self.assertTrue(
+ cur.rowcount in (-1, 1),
+ "cursor.rowcount should == number of rows returned, or "
+ "set to -1 after executing a select statement",
+ )
self.executeDDL2(cur)
-## self.assertEqual(cur.rowcount,-1,
-## 'cursor.rowcount not being reset to -1 after executing '
-## 'no-result statements'
-## )
+ ## self.assertEqual(cur.rowcount,-1,
+ ## 'cursor.rowcount not being reset to -1 after executing '
+ ## 'no-result statements'
+ ## )
finally:
con.close()
def test_callproc(self):
- pass # performed in test_MySQL_capabilities
-
- def help_nextset_setUp(self,cur):
- ''' Should create a procedure called deleteme
- that returns two result sets, first the
- number of rows in booze then "name from booze"
- '''
- sql="""
+ pass # performed in test_MySQL_capabilities
+
+ def help_nextset_setUp(self, cur):
+ """Should create a procedure called deleteme
+ that returns two result sets, first the
+ number of rows in booze then "name from booze"
+ """
+ sql = """
create procedure deleteme()
begin
select count(*) from %(tp)sbooze;
select name from %(tp)sbooze;
end
- """ % dict(tp=self.table_prefix)
+ """ % dict(
+ tp=self.table_prefix
+ )
cur.execute(sql)
- def help_nextset_tearDown(self,cur):
- 'If cleaning up is needed after nextSetTest'
+ def help_nextset_tearDown(self, cur):
+ "If cleaning up is needed after nextSetTest"
cur.execute("drop procedure deleteme")
def test_nextset(self):
- from warnings import warn
con = self._connect()
try:
cur = con.cursor()
- if not hasattr(cur,'nextset'):
+ if not hasattr(cur, "nextset"):
return
try:
self.executeDDL1(cur)
- sql=self._populate()
+ sql = self._populate()
for sql in self._populate():
cur.execute(sql)
self.help_nextset_setUp(cur)
- cur.callproc('deleteme')
- numberofrows=cur.fetchone()
- assert numberofrows[0]== len(self.samples)
+ cur.callproc("deleteme")
+ numberofrows = cur.fetchone()
+ assert numberofrows[0] == len(self.samples)
assert cur.nextset()
- names=cur.fetchall()
+ names = cur.fetchall()
assert len(names) == len(self.samples)
- s=cur.nextset()
+ s = cur.nextset()
if s:
empty = cur.fetchall()
- self.assertEqual(len(empty), 0,
- "non-empty result set after other result sets")
- #warn("Incompatibility: MySQL returns an empty result set for the CALL itself",
+ self.assertEqual(
+ len(empty), 0, "non-empty result set after other result sets"
+ )
+ # warn("Incompatibility: MySQL returns an empty result set for the CALL itself",
# Warning)
- #assert s == None,'No more return sets, should return None'
+ # assert s == None,'No more return sets, should return None'
finally:
self.help_nextset_tearDown(cur)
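The rewritten `test_nextset` above tolerates the extra empty result set MySQL returns for the `CALL` itself. A sketch of draining a multi-result procedure with PyMySQL, assuming a running server, placeholder connection settings, and the `deleteme` procedure from `help_nextset_setUp`:

```py
# Sketch of consuming the result sets produced by the `deleteme` procedure
# defined in help_nextset_setUp above. Connection parameters are placeholders
# and a running MySQL server is assumed.
import pymysql

con = pymysql.connect(host="127.0.0.1", user="root", database="test_pymysql")
try:
    cur = con.cursor()
    cur.callproc("deleteme")      # first set: select count(*) from booze
    print(cur.fetchone())
    assert cur.nextset()          # second set: select name from booze
    print(cur.fetchall())
    if cur.nextset():             # MySQL may add an empty set for the CALL itself
        assert len(cur.fetchall()) == 0
finally:
    con.close()
```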
diff --git a/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py b/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py
index 17fc2cde5..1545fbb5e 100644
--- a/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py
+++ b/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py
@@ -1,17 +1,10 @@
-import sys
-try:
- import unittest2 as unittest
-except ImportError:
- import unittest
+import unittest
import pymysql
+
_mysql = pymysql
from pymysql.constants import FIELD_TYPE
from pymysql.tests import base
-from pymysql._compat import PY2, long_type
-
-if not PY2:
- basestring = str
class TestDBAPISet(unittest.TestCase):
@@ -33,17 +26,17 @@ class CoreModule(unittest.TestCase):
def test_NULL(self):
"""Should have a NULL constant."""
- self.assertEqual(_mysql.NULL, 'NULL')
+ self.assertEqual(_mysql.NULL, "NULL")
def test_version(self):
"""Version information sanity."""
- self.assertTrue(isinstance(_mysql.__version__, basestring))
+ self.assertTrue(isinstance(_mysql.__version__, str))
self.assertTrue(isinstance(_mysql.version_info, tuple))
self.assertEqual(len(_mysql.version_info), 5)
def test_client_info(self):
- self.assertTrue(isinstance(_mysql.get_client_info(), basestring))
+ self.assertTrue(isinstance(_mysql.get_client_info(), str))
def test_thread_safe(self):
self.assertTrue(isinstance(_mysql.thread_safe(), int))
@@ -62,40 +55,45 @@ def tearDown(self):
def test_thread_id(self):
tid = self.conn.thread_id()
- self.assertTrue(isinstance(tid, (int, long_type)),
- "thread_id didn't return an integral value.")
+ self.assertTrue(
+ isinstance(tid, int), "thread_id didn't return an integral value."
+ )
- self.assertRaises(TypeError, self.conn.thread_id, ('evil',),
- "thread_id shouldn't accept arguments.")
+ self.assertRaises(
+ TypeError,
+ self.conn.thread_id,
+ ("evil",),
+ "thread_id shouldn't accept arguments.",
+ )
def test_affected_rows(self):
- self.assertEqual(self.conn.affected_rows(), 0,
- "Should return 0 before we do anything.")
+ self.assertEqual(
+ self.conn.affected_rows(), 0, "Should return 0 before we do anything."
+ )
-
- #def test_debug(self):
- ## FIXME Only actually tests if you lack SUPER
- #self.assertRaises(pymysql.OperationalError,
- #self.conn.dump_debug_info)
+ # def test_debug(self):
+ ## FIXME Only actually tests if you lack SUPER
+ # self.assertRaises(pymysql.OperationalError,
+ # self.conn.dump_debug_info)
def test_charset_name(self):
- self.assertTrue(isinstance(self.conn.character_set_name(), basestring),
- "Should return a string.")
+ self.assertTrue(
+ isinstance(self.conn.character_set_name(), str), "Should return a string."
+ )
def test_host_info(self):
- assert isinstance(self.conn.get_host_info(), basestring), "should return a string"
+ assert isinstance(self.conn.get_host_info(), str), "should return a string"
def test_proto_info(self):
- self.assertTrue(isinstance(self.conn.get_proto_info(), int),
- "Should return an int.")
+ self.assertTrue(
+ isinstance(self.conn.get_proto_info(), int), "Should return an int."
+ )
def test_server_info(self):
- if sys.version_info[0] == 2:
- self.assertTrue(isinstance(self.conn.get_server_info(), basestring),
- "Should return an str.")
- else:
- self.assertTrue(isinstance(self.conn.get_server_info(), basestring),
- "Should return an str.")
+ self.assertTrue(
+            isinstance(self.conn.get_server_info(), str), "Should return a str."
+ )
+
if __name__ == "__main__":
unittest.main()
diff --git a/pymysql/util.py b/pymysql/util.py
deleted file mode 100644
index 3e82ac7b5..000000000
--- a/pymysql/util.py
+++ /dev/null
@@ -1,22 +0,0 @@
-import struct
-
-
-def byte2int(b):
- if isinstance(b, int):
- return b
- else:
- return struct.unpack("!B", b)[0]
-
-
-def int2byte(i):
- return struct.pack("!B", i)
-
-
-def join_bytes(bs):
- if len(bs) == 0:
- return ""
- else:
- rv = bs[0]
- for b in bs[1:]:
- rv += b
- return rv
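The deleted helpers have direct Python 3 equivalents, which is presumably why the module could go: indexing a `bytes` object yields an `int`, `bytes([i])` packs a single byte, and `b"".join()` concatenates. A quick sketch of the replacements:

```py
# Python 3 equivalents of the removed pymysql.util helpers.
import struct

b = b"\x05"
assert b[0] == 5                              # byte2int(b): indexing bytes gives an int
assert struct.unpack("!B", b)[0] == 5         # what the old helper unpacked
assert bytes([5]) == b"\x05"                  # int2byte(5)
assert b"".join([b"ab", b"cd"]) == b"abcd"    # join_bytes([...]) for non-empty input
```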
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 000000000..8cd9ddb45
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,66 @@
+[project]
+name = "PyMySQL"
+description = "Pure Python MySQL Driver"
+authors = [
+ {name = "Inada Naoki", email = "songofacandy@gmail.com"},
+ {name = "Yutaka Matsubara", email = "yutaka.matsubara@gmail.com"}
+]
+dependencies = []
+
+requires-python = ">=3.7"
+readme = "README.md"
+license = {text = "MIT License"}
+keywords = ["MySQL"]
+classifiers = [
+ "Development Status :: 5 - Production/Stable",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.7",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: Implementation :: CPython",
+ "Programming Language :: Python :: Implementation :: PyPy",
+ "Intended Audience :: Developers",
+ "License :: OSI Approved :: MIT License",
+ "Topic :: Database",
+]
+dynamic = ["version"]
+
+[project.optional-dependencies]
+"rsa" = [
+ "cryptography"
+]
+"ed25519" = [
+ "PyNaCl>=1.4.0"
+]
+
+[project.urls]
+"Project" = "https://github.com/PyMySQL/PyMySQL"
+"Documentation" = "https://pymysql.readthedocs.io/"
+
+[build-system]
+requires = ["setuptools>=61"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools.packages.find]
+namespaces = false
+include = ["pymysql*"]
+exclude = ["tests*", "pymysql.tests*"]
+
+[tool.setuptools.dynamic]
+version = {attr = "pymysql.VERSION_STRING"}
+
+[tool.ruff]
+exclude = [
+ "pymysql/tests/thirdparty",
+]
+
+[tool.ruff.lint]
+ignore = ["E721"]
+
+[tool.pdm.dev-dependencies]
+dev = [
+ "pytest-cov>=4.0.0",
+]
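The `[tool.setuptools.dynamic]` table above sources the distribution version from a package attribute at build time; a quick runtime check of the same value, assuming an installed `pymysql` that exposes `VERSION_STRING` as configured. The optional extras map to `pip install "PyMySQL[rsa]"` and `pip install "PyMySQL[ed25519]"`.

```py
# The version setuptools bakes into the wheel comes from this attribute
# (see version = {attr = "pymysql.VERSION_STRING"} above).
import pymysql

print(pymysql.VERSION_STRING)  # exact value depends on the checkout being built
```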
diff --git a/renovate.json b/renovate.json
new file mode 100644
index 000000000..09e16da6b
--- /dev/null
+++ b/renovate.json
@@ -0,0 +1,7 @@
+{
+ "$schema": "https://docs.renovatebot.com/renovate-schema.json",
+ "extends": [
+ "config:base"
+ ],
+ "dependencyDashboard": false
+}
diff --git a/requirements-dev.txt b/requirements-dev.txt
new file mode 100644
index 000000000..140d37067
--- /dev/null
+++ b/requirements-dev.txt
@@ -0,0 +1,4 @@
+cryptography
+PyNaCl>=1.4.0
+pytest
+pytest-cov
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index 70f051613..000000000
--- a/requirements.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-cryptography
-
diff --git a/runtests.py b/runtests.py
deleted file mode 100755
index ea3d9e8dd..000000000
--- a/runtests.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#!/usr/bin/env python
-import unittest2
-
-from pymysql._compat import PYPY, JYTHON, IRONPYTHON
-
-#import pymysql
-#pymysql.connections.DEBUG = True
-#pymysql._auth.DEBUG = True
-
-if not (PYPY or JYTHON or IRONPYTHON):
- import atexit
- import gc
- gc.set_debug(gc.DEBUG_UNCOLLECTABLE)
-
- @atexit.register
- def report_uncollectable():
- import gc
- if not gc.garbage:
- print("No garbages!")
- return
- print('uncollectable objects')
- for obj in gc.garbage:
- print(obj)
- if hasattr(obj, '__dict__'):
- print(obj.__dict__)
- for ref in gc.get_referrers(obj):
- print("referrer:", ref)
- print('---')
-
-import pymysql.tests
-unittest2.main(pymysql.tests, verbosity=2)
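With the unittest2-based runner removed, the suite is presumably driven by pytest (see requirements-dev.txt and the pdm dev dependencies above). A minimal programmatic equivalent of the old entry point, with the test paths assumed from this diff's layout:

```py
# Sketch only: a pytest replacement for the deleted runtests.py.
# The test paths are assumptions based on the directories touched in this diff.
import sys

import pytest

if __name__ == "__main__":
    # pytest.main() takes a list of CLI arguments and returns an exit code,
    # mirroring what unittest2.main() did for the old runner.
    sys.exit(pytest.main(["pymysql/tests", "tests", "-v"]))
```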
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index a26a846b9..000000000
--- a/setup.cfg
+++ /dev/null
@@ -1,17 +0,0 @@
-[flake8]
-ignore = E226,E301,E701
-exclude = tests,build
-max-line-length = 119
-
-[bdist_wheel]
-universal = 1
-
-[metadata]
-license = "MIT"
-license_file = LICENSE
-
-author=yutaka.matsubara
-author_email=yutaka.matsubara@gmail.com
-
-maintainer=INADA Naoki
-maintainer_email=songofacandy@gmail.com
diff --git a/setup.py b/setup.py
deleted file mode 100755
index 14650d1c7..000000000
--- a/setup.py
+++ /dev/null
@@ -1,39 +0,0 @@
-#!/usr/bin/env python
-import io
-from setuptools import setup, find_packages
-
-version = "0.9.2"
-
-with io.open('./README.rst', encoding='utf-8') as f:
- readme = f.read()
-
-setup(
- name="PyMySQL",
- version=version,
- url='https://github.com/PyMySQL/PyMySQL/',
- project_urls={
- "Documentation": "https://pymysql.readthedocs.io/",
- },
- description='Pure Python MySQL Driver',
- long_description=readme,
- packages=find_packages(exclude=['tests*', 'pymysql.tests*']),
- install_requires=[
- "cryptography",
- ],
- classifiers=[
- 'Development Status :: 5 - Production/Stable',
- 'Programming Language :: Python :: 2',
- 'Programming Language :: Python :: 2.7',
- 'Programming Language :: Python :: 3',
- 'Programming Language :: Python :: 3.4',
- 'Programming Language :: Python :: 3.5',
- 'Programming Language :: Python :: 3.6',
- 'Programming Language :: Python :: 3.7',
- 'Programming Language :: Python :: Implementation :: CPython',
- 'Programming Language :: Python :: Implementation :: PyPy',
- 'Intended Audience :: Developers',
- 'License :: OSI Approved :: MIT License',
- 'Topic :: Database',
- ],
- keywords="MySQL",
-)
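The metadata from setup.py now lives entirely in pyproject.toml, so builds go through the declared PEP 517 backend (`setuptools.build_meta`). A hedged sketch of driving that build with the `build` frontend, which is not pinned anywhere in this diff:

```py
# Sketch only: build a wheel through the PEP 517 backend declared in
# pyproject.toml using the "build" frontend (pip install build).
# The output directory "dist/" is an arbitrary choice for illustration.
from build import ProjectBuilder

builder = ProjectBuilder(".")  # project root containing pyproject.toml
wheel_path = builder.build("wheel", "dist/")
print("built", wheel_path)
```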
diff --git a/tests/test_auth.py b/tests/test_auth.py
index 7d8573442..e5e2a64e5 100644
--- a/tests/test_auth.py
+++ b/tests/test_auth.py
@@ -10,7 +10,10 @@
port = 3306
ca = os.path.expanduser("~/ca.pem")
-ssl = {'ca': ca, 'check_hostname': False}
+ssl = {"ca": ca, "check_hostname": False}
+
+pass_sha256 = "pass_sha256_01234567890123456789"
+pass_caching_sha2 = "pass_caching_sha2_01234567890123456789"
def test_sha256_no_password():
@@ -24,12 +27,16 @@ def test_sha256_no_passowrd_ssl():
def test_sha256_password():
- con = pymysql.connect(user="user_sha256", password="pass_sha256", host=host, port=port, ssl=None)
+ con = pymysql.connect(
+ user="user_sha256", password=pass_sha256, host=host, port=port, ssl=None
+ )
con.close()
def test_sha256_password_ssl():
- con = pymysql.connect(user="user_sha256", password="pass_sha256", host=host, port=port, ssl=ssl)
+ con = pymysql.connect(
+ user="user_sha256", password=pass_sha256, host=host, port=port, ssl=ssl
+ )
con.close()
@@ -38,26 +45,50 @@ def test_caching_sha2_no_password():
con.close()
-def test_caching_sha2_no_password():
+def test_caching_sha2_no_password_ssl():
con = pymysql.connect(user="nopass_caching_sha2", host=host, port=port, ssl=ssl)
con.close()
def test_caching_sha2_password():
- con = pymysql.connect(user="user_caching_sha2", password="pass_caching_sha2", host=host, port=port, ssl=None)
+ con = pymysql.connect(
+ user="user_caching_sha2",
+ password=pass_caching_sha2,
+ host=host,
+ port=port,
+ ssl=None,
+ )
con.close()
# Fast path of caching sha2
- con = pymysql.connect(user="user_caching_sha2", password="pass_caching_sha2", host=host, port=port, ssl=None)
+ con = pymysql.connect(
+ user="user_caching_sha2",
+ password=pass_caching_sha2,
+ host=host,
+ port=port,
+ ssl=None,
+ )
con.query("FLUSH PRIVILEGES")
con.close()
def test_caching_sha2_password_ssl():
- con = pymysql.connect(user="user_caching_sha2", password="pass_caching_sha2", host=host, port=port, ssl=ssl)
+ con = pymysql.connect(
+ user="user_caching_sha2",
+ password=pass_caching_sha2,
+ host=host,
+ port=port,
+ ssl=ssl,
+ )
con.close()
# Fast path of caching sha2
- con = pymysql.connect(user="user_caching_sha2", password="pass_caching_sha2", host=host, port=port, ssl=None)
+ con = pymysql.connect(
+ user="user_caching_sha2",
+ password=pass_caching_sha2,
+ host=host,
+ port=port,
+ ssl=None,
+ )
con.query("FLUSH PRIVILEGES")
con.close()
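The doubled connect in the caching_sha2 tests above is deliberate: the first successful full authentication populates the server-side credential cache, the second connection exercises the fast path, and `FLUSH PRIVILEGES` empties the cache so the next run starts cold again. A sketch of the same pattern outside the test module (host, port, and password are placeholders):

```py
# Sketch of the caching_sha2_password fast-path pattern used by the tests
# above; the connection parameters here are placeholders, not project values.
import pymysql

# First connection: full authentication, populating the server-side cache.
con = pymysql.connect(
    user="user_caching_sha2", password="...", host="127.0.0.1", port=3306
)
con.close()

# Second connection: the cached scramble allows the fast path.
con = pymysql.connect(
    user="user_caching_sha2", password="...", host="127.0.0.1", port=3306
)
# FLUSH PRIVILEGES clears the cache so a later run authenticates from scratch.
con.query("FLUSH PRIVILEGES")
con.close()
```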
diff --git a/tox.ini b/tox.ini
deleted file mode 100644
index a50364c9c..000000000
--- a/tox.ini
+++ /dev/null
@@ -1,10 +0,0 @@
-[tox]
-envlist = py26,py27,py33,py34,pypy,pypy3
-
-[testenv]
-commands = coverage run ./runtests.py
-deps = unittest2
- coverage
-passenv = USER
- PASSWORD
- PAMSERVICE